TDengine/test/new_test_framework/utils/common.py
WANG Xu c52c68aa4f
sync: apply remaining build system changes from monorepo (main)
The following commits could not be applied individually due to context
differences between the monorepo and the public repo's build files.
They have been applied as a cumulative diff to ensure the final state
matches the monorepo exactly:

- chore: sync CI files with 3.0 branch to eliminate merge conflicts (rd-public/tsdb!271)
- revert(refactor): dynamically link taosd taosudf taosmqtt against libtaosnative.so to reduce binary size (revert #183) (rd-public/tsdb!282)
- fix(docs): autofix formatting issues across all doc files (rd-public/tsdb!296)
- feat: support -DBUILD_SANITIZER=true on windows for debug build (rd-public/tsdb!291)
- feat(build): build cache, mirror, and sccache optimizations (rd-public/tsdb!326)
- docs: update image for three replica (rd-public/tsdb!324)
- enh: shared storage on windows (rd-public/tsdb!333)
- fix(cmake): convert ext_libs3 from git clone to URL tarball download (rd-public/tsdb!360)
- feat: dual-source deps and comprehensive docs/packaging (cherry-pick to main) (rd-public/tsdb!352)
- fix(cmake): guard DOWNLOAD_EXTRACT_TIMESTAMP for CMake < 3.24 and fix duplicate Cargo.lock entry (rd-public/tsdb!369)
- fix: test case execution failure in pytest.sh (rd-public/tsdb!338)
- enh: built-in compilation support for Python UDF plugins use abi3 (rd-public/tsdb!325)
2026-05-23 14:11:50 +08:00

3535 lines
136 KiB
Python

###################################################################
# Copyright (c) 2016 by TAOS Technologies, Inc.
# All rights reserved.
#
# This file is proprietary and confidential to TAOS Technologies.
# No part of this file may be reproduced, stored, transmitted,
# disclosed or used in any form or by any means other than as
# expressly provided by the written permission from Jianhui Tao
#
###################################################################
# -*- coding: utf-8 -*-
import random
import string
import concurrent
import requests
import time
import socket
import json
import toml
import subprocess
import os
import platform
import tempfile
from .boundary import DataBoundary
import taos
from .log import *
from .sql import *
from .server.dnodes import *
from .common import *
from .constant import *
from .epath import *
from dataclasses import dataclass, field
from decimal import Decimal, InvalidOperation
from typing import List
from datetime import datetime, timedelta
import re
@dataclass
class DataSet:
    """Column-wise container of generated sample values.

    Each field is a list that starts empty; ``get_order_set`` appends one
    value per generated row to every list, so index ``i`` across the lists
    describes row ``i``.
    """

    ts_data: List[int] = field(default_factory=list)
    int_data: List[int] = field(default_factory=list)
    bint_data: List[int] = field(default_factory=list)
    sint_data: List[int] = field(default_factory=list)
    tint_data: List[int] = field(default_factory=list)
    uint_data: List[int] = field(default_factory=list)
    ubint_data: List[int] = field(default_factory=list)
    usint_data: List[int] = field(default_factory=list)
    utint_data: List[int] = field(default_factory=list)
    float_data: List[float] = field(default_factory=list)
    double_data: List[float] = field(default_factory=list)
    # NOTE: get_order_set stores bool objects in this list.
    bool_data: List[int] = field(default_factory=list)
    vchar_data: List[str] = field(default_factory=list)
    nchar_data: List[str] = field(default_factory=list)

    def get_order_set(
        self,
        rows,
        int_step: int = 1,
        bint_step: int = 1,
        sint_step: int = 1,
        tint_step: int = 1,
        uint_step: int = 1,
        ubint_step: int = 1,
        usint_step: int = 1,
        utint_step: int = 1,
        float_step: float = 1,
        double_step: float = 1,
        bool_start: int = 1,
        vchar_prefix: str = "vachar_",
        vchar_step: int = 1,
        nchar_prefix: str = "nchar_测试_",
        nchar_step: int = 1,
        ts_step: int = 1,
    ):
        """Append ``rows`` ordered values to every column list.

        Numeric columns receive ``i * <step>`` reduced modulo the matching
        ``*_MAX`` constant (these constants are not defined in this block;
        presumably they come from the ``constant`` star import).
        Booleans alternate starting from ``bool_start``; string columns are
        ``<prefix><i * step>``.  Timestamps are milliseconds and decrease by
        ``ts_step`` per row from the moment the method is called.
        """
        # Read the clock once: sampling it per row let the base drift between
        # iterations, so consecutive rows could truncate to the same
        # millisecond and the spacing was not exactly ``ts_step``.
        base_ms = datetime.timestamp(datetime.now()) * 1000
        for i in range(rows):
            self.int_data.append(int(i * int_step % INT_MAX))
            self.bint_data.append(int(i * bint_step % BIGINT_MAX))
            self.sint_data.append(int(i * sint_step % SMALLINT_MAX))
            self.tint_data.append(int(i * tint_step % TINYINT_MAX))
            self.uint_data.append(int(i * uint_step % INT_UN_MAX))
            self.ubint_data.append(int(i * ubint_step % BIGINT_UN_MAX))
            self.usint_data.append(int(i * usint_step % SMALLINT_UN_MAX))
            self.utint_data.append(int(i * utint_step % TINYINT_UN_MAX))
            self.float_data.append(float(i * float_step % FLOAT_MAX))
            self.double_data.append(float(i * double_step % DOUBLE_MAX))
            self.bool_data.append(bool((i + bool_start) % 2))
            self.vchar_data.append(f"{vchar_prefix}{i * vchar_step}")
            self.nchar_data.append(f"{nchar_prefix}{i * nchar_step}")
            self.ts_data.append(int(base_ms - i * ts_step))

    def get_disorder_set(self, rows, **kwargs):
        """Unfinished stub: resolves per-type bounds from ``kwargs`` and stops.

        The computed bounds are never used and nothing is appended to the
        data lists; ``rows`` is ignored.
        """
        # NOTE(review): the locals below are discarded; kept to record the
        # keyword names the eventual implementation is expected to accept.
        for k, v in kwargs.items():
            int_low = v if k == "int_low" else INT_MIN
            int_up = v if k == "int_up" else INT_MAX
            bint_low = v if k == "bint_low" else BIGINT_MIN
            bint_up = v if k == "bint_up" else BIGINT_MAX
            sint_low = v if k == "sint_low" else SMALLINT_MIN
            sint_up = v if k == "sint_up" else SMALLINT_MAX
            tint_low = v if k == "tint_low" else TINYINT_MIN
            tint_up = v if k == "tint_up" else TINYINT_MAX
        pass
class TDCom:
def __init__(self):
    """Populate every default used by the table/stream helper methods.

    Reads the current time via ``self.genTs`` and instantiates
    ``DataBoundary``; everything else is constant configuration.
    """
    # --- schemaless / environment settings --------------------------------
    self.sml_type = None
    self.env_setting = None
    self.smlChildTableName_value = None
    self.defaultJSONStrType_value = None
    self.smlTagNullName_value = None

    # --- defaults for generated names and schemas -------------------------
    self.default_varchar_length = 6
    self.default_nchar_length = 6
    self.default_varchar_datatype = "letters"
    self.default_nchar_datatype = "letters"
    self.default_tagname_prefix = "t"
    self.default_colname_prefix = "c"
    self.default_stbname_prefix = "stb"
    self.default_ctbname_prefix = "ctb"
    self.default_tbname_prefix = "tb"
    self.default_tag_index_start_num = 1
    self.default_column_index_start_num = 1
    self.default_stbname_index_start_num = 1
    self.default_ctbname_index_start_num = 1
    self.default_tbname_index_start_num = 1
    self.default_tagts_name = "ts"
    self.default_colts_name = "ts"
    self.dbname = "test"
    self.stb_name = "stb"
    self.ctb_name = "ctb"
    self.tb_name = "tb"
    self.tbname = str()
    self.need_tagts = False
    self.tag_type_str = ""
    self.column_type_str = ""
    self.columns_str = None
    self.ts_value = None
    self.tag_value_list = list()
    self.column_value_list = list()
    self.full_type_list = [
        "tinyint",
        "smallint",
        "int",
        "bigint",
        "tinyint unsigned",
        "smallint unsigned",
        "int unsigned",
        "bigint unsigned",
        "float",
        "double",
        "binary",
        "nchar",
        "bool",
    ]
    # Assigned once; the original repeated this identical list after
    # ``self.Boundary``.
    self.white_list = [
        "statsd",
        "node_exporter",
        "collectd",
        "icinga2",
        "tcollector",
        "information_schema",
        "performance_schema",
    ]
    self.Boundary = DataBoundary()

    # --- stream test settings ---------------------------------------------
    self.case_name = str()
    # Assigned once; the original repeated both suffixes further down.
    self.des_table_suffix = "_output"
    self.stream_suffix = "_stream"
    self.range_count = 5
    self.default_interval = 5
    self.stream_timeout = 12
    self.create_stream_sleep = 0.5
    self.record_history_ts = str()
    self.precision = "ms"
    self.date_time = self.genTs(precision=self.precision)[0]
    self.subtable = True
    self.partition_tbname_alias = "ptn_alias" if self.subtable else ""
    self.partition_col_alias = "pcol_alias" if self.subtable else ""
    self.partition_tag_alias = "ptag_alias" if self.subtable else ""
    self.partition_expression_alias = "pexp_alias" if self.subtable else ""
    self.subtable_prefix = "prefix_" if self.subtable else ""
    self.subtable_suffix = "_suffix" if self.subtable else ""

    # --- select lists: the first 15 entries avoid tag columns --------------
    self.downsampling_function_list = [
        "min(c1)",
        "max(c2)",
        "sum(c3)",
        "first(c4)",
        "last(c5)",
        "apercentile(c6, 50)",
        "avg(c7)",
        "count(c8)",
        "spread(c1)",
        "stddev(c2)",
        "hyperloglog(c11)",
        "timediff(1, 0, 1h)",
        "timezone()",
        "to_iso8601(1)",
        'to_unixtimestamp("1970-01-01T08:00:00+08:00")',
        "min(t1)",
        "max(t2)",
        "sum(t3)",
        "first(t4)",
        "last(t5)",
        "apercentile(t6, 50)",
        "avg(t7)",
        "count(t8)",
        "spread(t1)",
        "stddev(t2)",
        "hyperloglog(t11)",
    ]
    self.stb_output_select_str = ",".join(
        list(map(lambda x: f"`{x}`", self.downsampling_function_list))
    )
    self.tb_output_select_str = ",".join(
        list(map(lambda x: f"`{x}`", self.downsampling_function_list[0:15]))
    )
    self.stb_source_select_str = ",".join(self.downsampling_function_list)
    self.tb_source_select_str = ",".join(self.downsampling_function_list[0:15])
    # Same as the downsampling list minus first(c4)/last(c5), so the
    # tag-free prefix is 13 entries long here.
    self.fill_function_list = [
        "min(c1)",
        "max(c2)",
        "sum(c3)",
        "apercentile(c6, 50)",
        "avg(c7)",
        "count(c8)",
        "spread(c1)",
        "stddev(c2)",
        "hyperloglog(c11)",
        "timediff(1, 0, 1h)",
        "timezone()",
        "to_iso8601(1)",
        'to_unixtimestamp("1970-01-01T08:00:00+08:00")',
        "min(t1)",
        "max(t2)",
        "sum(t3)",
        "first(t4)",
        "last(t5)",
        "apercentile(t6, 50)",
        "avg(t7)",
        "count(t8)",
        "spread(t1)",
        "stddev(t2)",
        "hyperloglog(t11)",
    ]
    self.fill_stb_output_select_str = ",".join(
        list(map(lambda x: f"`{x}`", self.fill_function_list))
    )
    self.fill_stb_source_select_str = ",".join(self.fill_function_list)
    self.fill_tb_output_select_str = ",".join(
        list(map(lambda x: f"`{x}`", self.fill_function_list[0:13]))
    )
    self.fill_tb_source_select_str = ",".join(self.fill_function_list[0:13])
    self.ext_tb_source_select_str = ",".join(self.downsampling_function_list[0:13])
    self.stream_case_when_tbname = "tbname"
    self.update = True
    self.disorder = True
    if self.disorder:
        self.update = False
    self.partition_by_downsampling_function_list = [
        "min(c1)",
        "max(c2)",
        "sum(c3)",
        "first(c4)",
        "last(c5)",
        "count(c8)",
        "spread(c1)",
        "stddev(c2)",
        "hyperloglog(c11)",
        "min(t1)",
        "max(t2)",
        "sum(t3)",
        "first(t4)",
        "last(t5)",
        "count(t8)",
        "spread(t1)",
        "stddev(t2)",
    ]

    # --- filter expressions -----------------------------------------------
    # Built by implicit concatenation so the value cannot pick up stray
    # whitespace from source indentation (the original used backslash
    # continuations inside one literal).
    self.stb_data_filter_sql = (
        f'ts >= {self.date_time}+1s and c1 = 1 or c2 > 1 and c3 != 4 or c4 <= 3 and c9 <> 0 or c10 is not Null or c11 is Null or '
        'c12 between "na" and "nchar4" and c11 not between "bi" and "binary" and c12 match "nchar[19]" and c12 nmatch "nchar[25]" or c13 = True or '
        'c5 in (1, 2, 3) or c6 not in (6, 7) and c12 like "nch%" and c11 not like "bina_" and c6 < 10 or c12 is Null or c8 >= 4 and t1 = 1 or t2 > 1 '
        'and t3 != 4 or c4 <= 3 and t9 <> 0 or t10 is not Null or t11 is Null or t12 between "na" and "nchar4" and t11 not between "bi" and "binary" '
        'or t12 match "nchar[19]" or t12 nmatch "nchar[25]" or t13 = True or t5 in (1, 2, 3) or t6 not in (6, 7) and t12 like "nch%" '
        'and t11 not like "bina_" and t6 <= 10 or t12 is Null or t8 >= 4'
    )
    # Everything before the first tag predicate is usable on a plain table.
    self.tb_data_filter_sql = self.stb_data_filter_sql.partition(" and t1")[0]
    self.filter_source_select_elm = "*"
    self.stb_filter_des_select_elm = "ts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, t1, t2, t3, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13"
    self.partitial_stb_filter_des_select_elm = ",".join(
        self.stb_filter_des_select_elm.split(",")[:3]
    )
    self.exchange_stb_filter_des_select_elm = ",".join(
        [
            self.stb_filter_des_select_elm.split(",")[0],
            self.stb_filter_des_select_elm.split(",")[2],
            self.stb_filter_des_select_elm.split(",")[1],
        ]
    )
    self.partitial_ext_tb_source_select_str = ",".join(
        self.downsampling_function_list[0:2]
    )
    self.tb_filter_des_select_elm = self.stb_filter_des_select_elm.partition(
        ", t1"
    )[0]
    self.tag_filter_des_select_elm = self.stb_filter_des_select_elm.partition(
        "c13, "
    )[2]
    self.partition_by_stb_output_select_str = ",".join(
        list(map(lambda x: f"`{x}`", self.partition_by_downsampling_function_list))
    )
    self.partition_by_stb_source_select_str = ",".join(
        self.partition_by_downsampling_function_list
    )
    self.exchange_tag_filter_des_select_elm = ",".join(
        [
            self.stb_filter_des_select_elm.partition("c13, ")[2].split(",")[0],
            self.stb_filter_des_select_elm.partition("c13, ")[2].split(",")[2],
            self.stb_filter_des_select_elm.partition("c13, ")[2].split(",")[1],
        ]
    )
    self.partitial_tag_filter_des_select_elm = ",".join(
        self.stb_filter_des_select_elm.partition("c13, ")[2].split(",")[:3]
    )
    self.partitial_tag_stb_filter_des_select_elm = "ts, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11, c12, c13, t1, t3, t2, t4, t5, t6, t7, t8, t9, t10, t11, t12, t13"
    self.cast_tag_filter_des_select_elm = "t5,t11,t13"
    self.cast_tag_stb_filter_des_select_elm = "ts, t1, t2, t3, t4, cast(t1 as TINYINT UNSIGNED), t6, t7, t8, t9, t10, cast(t2 as varchar(256)), t12, cast(t3 as bool)"
    self.tag_count = len(self.tag_filter_des_select_elm.split(","))
    self.state_window_range = list()
    self.custom_col_val = 0
    self.part_val_list = [1, 2]

    # --- filesystem locations (can be overridden through ``init``) --------
    self.taos_bin_path = "/usr/bin"
    self.taos_cfg_path = "/etc/taos"
    # "sim" directory four levels above the directory containing this file.
    self.work_dir = os.path.join(
        os.path.dirname(
            os.path.dirname(
                os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
            )
        ),
        "sim",
    )
# def init(self, conn, logSql):
# # tdSql.init(conn.cursor(), logSql)
def init(self, taos_bin_path, taos_cfg_path, work_dir):
    """Replace the binary path, config path and working directory.

    Each argument is stored on the instance under the attribute of the
    same name.
    """
    overrides = {
        "taos_bin_path": taos_bin_path,
        "taos_cfg_path": taos_cfg_path,
        "work_dir": work_dir,
    }
    for attribute, value in overrides.items():
        setattr(self, attribute, value)
def preDefine(self):
    """Return the REST auth header followed by five local endpoint URLs.

    The tuple order is: header, sql, sqlt, sqlutc, influxdb write,
    opentsdb telnet put.  Callers index into it positionally.
    """
    # Basic credentials: base64 of "root:taosdata".
    auth_header = {"Authorization": "Basic cm9vdDp0YW9zZGF0YQ=="}
    endpoints = (
        "http://127.0.0.1:6041/rest/sql",
        "http://127.0.0.1:6041/rest/sqlt",
        "http://127.0.0.1:6041/rest/sqlutc",
        "http://127.0.0.1:6041/influxdb/v1/write",
        "http://127.0.0.1:6041/opentsdb/v1/put/telnet",
    )
    return (auth_header, *endpoints)
def genTcpParam(self):
    """Return ``(max_bytes, host, port)`` for the local TCP endpoint."""
    one_mebibyte = 1024 * 1024
    return one_mebibyte, "127.0.0.1", 6046
def tcpClient(self, input):
    """Send ``input`` over a short-lived TCP connection.

    The target host and port come from ``self.genTcpParam()``.

    :param input: text to transmit; encoded with ``str.encode()`` defaults.
    :raises OSError: if connecting or sending fails.
    """
    # Use this instance rather than the module-level ``tdCom`` global, and
    # query the parameters once instead of three times.
    _, host, port = self.genTcpParam()
    # The context manager closes the socket even when connect/send raises.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.connect((host, port))
        # sendall keeps writing until the whole payload is out; a bare
        # send() may transmit only part of it.
        sock.sendall(input.encode())
def restApiPost(self, sql):
    """POST ``sql`` (UTF-8 encoded) to the REST sql endpoint.

    The response is discarded; nothing is returned.
    """
    sql_url = self.preDefine()[1]
    payload = sql.encode("utf-8")
    auth_header = self.preDefine()[0]
    requests.post(sql_url, payload, headers=auth_header)
def createDb(self, dbname="test", db_update_tag=0, api_type="taosc"):
    """Drop and recreate ``dbname`` with microsecond precision, then ``use`` it.

    :param db_update_tag: when not 0, ``update 1`` is appended to the
        create statement.
    :param api_type: ``"taosc"`` runs the statements through ``tdSql``,
        ``"restful"`` through ``restApiPost``; any other value skips the
        drop/create and only issues the final ``use``.
    """
    if api_type == "taosc":
        run = tdSql.execute
    elif api_type == "restful":
        run = self.restApiPost
    else:
        run = None

    if run is not None:
        run(f"drop database if exists {dbname}")
        run(
            f"create database if not exists {dbname} precision 'us'"
            if db_update_tag == 0
            else f"create database if not exists {dbname} precision 'us' update 1"
        )
    # The native connection is always switched to the database, whichever
    # API created it.
    tdSql.execute(f"use {dbname}")
def genUrl(self, url_type, dbname, precision):
    """Build the endpoint URL for a schemaless or SQL request.

    * ``"influxdb"`` -> write URL with ``?db=<dbname>`` and, when
      ``precision`` is not None, ``&precision=<precision>``.
    * ``"telnet"``   -> telnet put URL followed by ``/<dbname>``.
    * anything else  -> the plain REST sql URL.
    """
    if url_type == "telnet":
        return self.preDefine()[5] + "/" + dbname
    if url_type != "influxdb":
        return self.preDefine()[1]
    influx_url = self.preDefine()[4] + "?" + "db=" + dbname
    if precision is not None:
        influx_url = influx_url + "&precision=" + precision
    return influx_url
def schemalessApiPost(
    self, sql, url_type="influxdb", dbname="test", precision=None
):
    """POST a schemaless payload and return the HTTP response.

    :param sql: payload text; sent UTF-8 encoded.
    :param url_type: forwarded to ``genUrl`` to select the endpoint.
    :param dbname: target database, forwarded to ``genUrl``.
    :param precision: optional precision, forwarded to ``genUrl``.
    :return: the object returned by ``requests.post``.
    """
    # The two original branches were identical, and any other url_type left
    # ``url`` unbound.  genUrl already handles every url_type, so delegate
    # unconditionally.
    url = self.genUrl(url_type, dbname, precision)
    res = requests.post(url, sql.encode("utf-8"), headers=self.preDefine()[0])
    return res
def cleanTb(self, type="taosc", dbname="db"):
    """
    Drop every super table listed by ``show <dbname>.stables``.

    type is taosc or restful; any other value drops nothing.
    Each table is dropped with a back-quoted name and, unless its name
    starts with a digit, once more with the bare name.
    """
    rows = tdSql.query(f"show {dbname}.stables", True)
    for row in rows:
        stb = row[0]
        if type == "taosc":
            run = tdSql.execute
        elif type == "restful":
            run = self.restApiPost
        else:
            continue
        run(f"drop table if exists {dbname}.`{stb}`")
        if not stb[0].isdigit():
            run(f"drop table if exists {dbname}.{stb}")
def dateToTs(self, datetime_input):
    """Convert ``"YYYY-mm-dd HH:MM:SS.ffffff"`` text to whole epoch seconds.

    The text is interpreted as local time (``time.mktime``); the fractional
    part must be present but does not affect the result.
    """
    parsed = time.strptime(datetime_input, "%Y-%m-%d %H:%M:%S.%f")
    epoch_seconds = time.mktime(parsed)
    return int(epoch_seconds)
def genTs(self, precision="ms", ts="", protype="taosc", ns_tag=None):
    """
    gen ts and datetime

    :param precision: "ms" (or None), "us" or "ns".
    :param ts: base timestamp; "" or None means "use the current time".
        For "ns" it is an integer nanosecond count, otherwise seconds.
    :param protype: "taosc" or "restful"; selects the datetime text format.
    :param ns_tag: only for "ns": any non-None value requests a formatted
        datetime string instead of the raw integer.
    :return: ``(ts, dt)`` where ``ts`` is an integer in ``precision`` units.
        ``dt`` is local-time text, except that "ns" without ``ns_tag`` and
        with a non-restful ``protype`` returns ``ts`` itself, and "ms" with
        an unrecognised ``protype`` returns a ``datetime`` object.
    :raises ValueError: for an unsupported ``precision`` (previously this
        surfaced as an UnboundLocalError at the return statement).
    """

    def _stamp(whole_seconds, fraction, width):
        # "YYYY-mm-dd HH:MM:SS.<fraction>" with the fraction zero-padded.
        text = datetime.fromtimestamp(whole_seconds).strftime("%Y-%m-%d %H:%M:%S")
        return text + "." + str(int(fraction)).zfill(width)

    use_clock = ts == "" or ts is None

    if precision == "ns":
        if use_clock:
            ts = time.time_ns()
        if ns_tag is not None or protype == "restful":
            dt = _stamp(ts // 1000000000, ts % 1000000000, 9)
        else:
            dt = ts
        return ts, dt

    if precision not in ("ms", "us", None):
        raise ValueError(f"unsupported precision: {precision!r}")

    if use_clock:
        ts = time.time()

    if precision == "us":
        ts = int(round(ts * 1000000))
        return ts, _stamp(ts // 1000000, ts % 1000000, 6)

    # "ms" or None
    ts = int(round(ts * 1000))
    dt = datetime.fromtimestamp(ts // 1000)
    if protype == "taosc":
        # Native client text carries six fractional digits.
        dt = _stamp(ts // 1000, ts % 1000, 3) + "000"
    elif protype == "restful":
        dt = _stamp(ts // 1000, ts % 1000, 3)
    return ts, dt
def get_long_name(self, length=10, mode="letters"):
    """
    generate a random name of ``length`` characters
    mode could be numbers/letters/letters_mixed/mixed
    (any unrecognised mode behaves like mixed: lowercase letters + digits)
    """
    lowercase = string.ascii_letters.lower()
    alphabets = (
        ("numbers", string.digits),
        ("letters", lowercase),
        ("letters_mixed", string.ascii_letters.upper() + lowercase),
    )
    population = next(
        (chars for name, chars in alphabets if mode == name),
        lowercase + string.digits,
    )
    return "".join(random.choices(population, k=length))
def getLongName(self, len, mode="mixed"):
    """
    generate a random name of ``len`` characters
    mode could be numbers/letters/letters_mixed/mixed
    (any unrecognised mode behaves like mixed: lowercase letters + digits)
    """
    # NOTE: the ``len`` parameter shadows the builtin inside this method.
    lowercase = string.ascii_letters.lower()
    if mode == "numbers":
        alphabet = string.digits
    elif mode == "letters":
        alphabet = lowercase
    elif mode == "letters_mixed":
        alphabet = string.ascii_letters.upper() + lowercase
    else:
        alphabet = lowercase + string.digits
    return "".join(random.choice(alphabet) for _ in range(len))
def restartTaosd(self, index=1, db_name="db"):
    """Stop dnode ``index``, start it again without sleeping, then ``use db_name``."""
    for step in ("stop", "startWithoutSleep"):
        getattr(tdDnodes, step)(index)
    tdSql.execute(f"use {db_name}")
def typeof(self, variate):
    """Return the builtin type name of ``variate`` as a string, or None.

    Only exact types are recognised (a subclass instance yields None), so
    ``True`` is reported as "bool", not "int".
    """
    names_by_type = {
        int: "int",
        str: "str",
        float: "float",
        bool: "bool",
        list: "list",
        tuple: "tuple",
        dict: "dict",
        set: "set",
    }
    return names_by_type.get(type(variate))
def splitNumLetter(self, input_mix_str):
    """Split a string into ``(digits, others)``, dropping whitespace.

    Both parts keep the original character order; "others" is every
    character that is neither a digit nor whitespace.
    """
    digits = "".join(ch for ch in input_mix_str if ch.isdigit())
    others = "".join(
        ch for ch in input_mix_str if not ch.isdigit() and not ch.isspace()
    )
    return digits, others
def smlPass(self, func):
    """Decorator that runs ``func`` only when smlChildTableName is "ID".

    The setting is currently hard-coded to "no" (the server lookup is
    commented out), so the returned wrapper never calls ``func`` and
    always returns None.
    """
    smlChildTableName = "no"

    def wrapper(*args):
        # if tdSql.getVariable("smlChildTableName")[0].upper() == "ID":
        if smlChildTableName.upper() != "ID":
            return None
        return func(*args)

    return wrapper
def close(self):
    """Close ``self.cursor``.

    NOTE(review): no visible code in this class assigns ``self.cursor``;
    confirm it is set before this is called.
    """
    cursor = self.cursor
    cursor.close()
########################################################################################################################################
# new common API
########################################################################################################################################
def create_database(self, td_sql, dbName="test", dropFlag=1, **kwargs):
    """Create ``dbName`` through ``td_sql`` and wait for the transaction to settle.

    :param dropFlag: when 1, the database is dropped first.
    :param kwargs: database options appended as ``<name> <value>``; the
        ``precision`` value is additionally wrapped in double quotes.
        Option names the original comment lists: vgroups replica precision
        strict wal fsync comp cachelast single_stable buffer pagesize pages
        minrows maxrows duration keep retentions

    Calls ``tdLog.exit`` if ``waitTransactionZeroWithTdsql`` reports failure.
    """
    if dropFlag == 1:
        td_sql.execute(f"drop database if exists {dbName}")

    sqlString = f"create database if not exists {dbName}"
    clauses = [
        f'{param} "{value}"' if param == "precision" else f"{param} {value}"
        for param, value in kwargs.items()
    ]
    if clauses:
        # must have a space before the options
        sqlString += " " + " ".join(clauses).strip()

    tdLog.debug(f"create db sql: {sqlString}")
    td_sql.execute(sqlString)

    if self.waitTransactionZeroWithTdsql(td_sql):
        tdLog.debug(f"complete to create database {dbName}")
    else:
        tdLog.exit("Transaction did not reach zero, aborting test.")
    return
# def create_stable(self,td_sql, dbName,stbName,column_elm_list=None, tag_elm_list=None):
# colSchema = ''
# for i in range(columnDict['int']):
# colSchema += ', c%d int'%i
# tagSchema = ''
# for i in range(tagDict['int']):
# if i > 0:
# tagSchema += ','
# tagSchema += 't%d int'%i
# td_sql.execute("create table if not exists %s.%s (ts timestamp %s) tags(%s)"%(dbName, stbName, colSchema, tagSchema))
# tdLog.debug("complete to create %s.%s" %(dbName, stbName))
# return
# def create_ctables(self,td_sql, dbName,stbName,ctbNum,tagDict):
# td_sql.execute("use %s" %dbName)
# tagsValues = ''
# for i in range(tagDict['int']):
# if i > 0:
# tagsValues += ','
# tagsValues += '%d'%i
# pre_create = "create table"
# sql = pre_create
# #tdLog.debug("doing create one stable %s and %d child table in %s ..." %(stbname, count ,dbname))
# for i in range(ctbNum):
# sql += " %s_%d using %s tags(%s)"%(stbName,i,stbName,tagsValues)
# if (i > 0) and (i%100 == 0):
# td_sql.execute(sql)
# sql = pre_create
# if sql != pre_create:
# td_sql.execute(sql)
# tdLog.debug("complete to create %d child tables in %s.%s" %(ctbNum, dbName, stbName))
# return
# def insert_data(self,td_sql,dbName,stbName,ctbNum,rowsPerTbl,batchNum,startTs=0):
# tdLog.debug("start to insert data ............")
# td_sql.execute("use %s" %dbName)
# pre_insert = "insert into "
# sql = pre_insert
# if startTs == 0:
# t = time.time()
# startTs = int(round(t * 1000))
# #tdLog.debug("doing insert data into stable:%s rows:%d ..."%(stbName, allRows))
# for i in range(ctbNum):
# sql += " %s_%d values "%(stbName,i)
# for j in range(rowsPerTbl):
# sql += "(%d, %d, %d)"%(startTs + j, j, j)
# if (j > 0) and ((j%batchNum == 0) or (j == rowsPerTbl - 1)):
# td_sql.execute(sql)
# if j < rowsPerTbl - 1:
# sql = "insert into %s_%d values " %(stbName,i)
# else:
# sql = "insert into "
# #end sql
# if sql != pre_insert:
# #print("insert sql:%s"%sql)
# td_sql.execute(sql)
# tdLog.debug("insert data ............ [OK]")
# return
def getBuildPath(self):
    """Locate the build directory that contains the ``taosd`` binary.

    The project root is derived from this file's real path by looking for
    a known marker ("taos-community", "community", "TDengine", else
    "test").  The tree below it is walked for ``taosd`` / ``taosd.exe``;
    hits whose parent directory path contains "packaging" are ignored.

    :return: the directory of the first hit with the last
        ``len("/build/bin")`` characters removed, or ``""`` when no binary
        is found.
    """
    selfPath = os.path.dirname(os.path.realpath(__file__))
    if "taos-community" in selfPath:
        # tsdb repo layout: /mnt/tsdb/source/taos-community/test/...
        projPath = selfPath[: selfPath.find("source/taos-community")]
    elif "community" in selfPath:
        projPath = selfPath[: selfPath.find("community")]
    elif "TDengine" in selfPath:
        projPath = selfPath[: selfPath.find("TDengine") + len("TDengine")]
    else:
        projPath = selfPath[: selfPath.find("test")]

    # Default for "not found": previously the name was unbound in that
    # case and the return statement raised UnboundLocalError.
    buildPath = ""
    for root, dirs, files in os.walk(projPath):
        if ".git" in root:
            continue
        if "taosd" in files or "taosd.exe" in files:
            rootRealPath = os.path.dirname(os.path.realpath(root))
            if "packaging" not in rootRealPath:
                # Assumes the binary sits in "<buildPath>/build/bin".
                buildPath = root[: len(root) - len("/build/bin")]
                break
    # if platform.system().lower() == 'windows':
    #     win_sep = "\\"
    #     buildPath = buildPath.replace(win_sep,'/')
    return buildPath
def getTaosdPath(self, dnodeID="dnode1"):
    """Return the directory of ``dnodeID`` inside ``self.work_dir``."""
    components = (self.work_dir, dnodeID)
    return os.path.join(*components)
def getClientCfgPath(self):
    """Return the client config directory, ``<work_dir>/psim/cfg``."""
    components = (self.work_dir, "psim", "cfg")
    return os.path.join(*components)
# buildPath = self.getBuildPath()
# if (buildPath == ""):
# tdLog.exit("taosd not found!")
# else:
# tdLog.info("taosd found in %s" % buildPath)
# cfgPath = buildPath + "/../sim/psim/cfg"
# tdLog.info("cfgPath: %s" % cfgPath)
# return cfgPath
def newcon(
    self,
    host="localhost",
    port=6030,
    user="root",
    password="taosdata",
    database=None,
):
    """Open and return a new ``taos`` connection.

    Unlike ``newcur``, no client config directory is passed.
    """
    connect_args = {
        "host": host,
        "user": user,
        "password": password,
        "port": port,
        "database": database,
    }
    return taos.connect(**connect_args)
def newcur(
    self,
    host="localhost",
    port=6030,
    user="root",
    password="taosdata",
    database=None,
):
    """Open a connection using the client config dir and return a cursor on it."""
    connect_args = {
        "host": host,
        "user": user,
        "password": password,
        "config": self.getClientCfgPath(),
        "port": port,
        "database": database,
    }
    connection = taos.connect(**connect_args)
    return connection.cursor()
def newTdSql(
    self,
    host="localhost",
    port=6030,
    user="root",
    password="taosdata",
    database=None,
):
    """Return a fresh ``TDSql`` initialised with a new cursor (SQL logging off)."""
    sql_helper = TDSql()
    cursor = self.newcur(
        host=host, port=port, user=user, password=password, database=database
    )
    sql_helper.init(cursor, False)
    return sql_helper
def newcurWithTimezone(
    self, timezone, host="localhost", port=6030, user="root", password="taosdata"
):
    """Open a connection with an explicit ``timezone`` and return a cursor on it.

    Uses the client config directory; no database is selected.
    """
    connect_args = {
        "host": host,
        "user": user,
        "password": password,
        "config": self.getClientCfgPath(),
        "port": port,
        "timezone": timezone,
    }
    connection = taos.connect(**connect_args)
    return connection.cursor()
def newTdSqlWithTimezone(
    self, timezone, host="localhost", port=6030, user="root", password="taosdata"
):
    """Return a fresh ``TDSql`` whose cursor was opened with ``timezone`` (SQL logging off)."""
    sql_helper = TDSql()
    cursor = self.newcurWithTimezone(
        host=host, port=port, user=user, password=password, timezone=timezone
    )
    sql_helper.init(cursor, False)
    return sql_helper
################################################################################################################
# port from the common.py of new test frame
################################################################################################################
def gen_default_tag_str(self):
    """Build the default tag definition list, one tag per ``full_type_list`` entry.

    Tags are named ``<default_tagname_prefix><n>``; varchar/binary and nchar
    get their default length in parentheses.  When ``need_tagts`` is set, a
    leading timestamp tag is prepended.

    NOTE(review): ``default_tag_index_start_num`` is advanced and, unlike
    the column counterpart, not reset here, so a second call continues the
    numbering — confirm this is intended.
    """
    definitions = []
    for tag_type in self.full_type_list:
        kind = tag_type.lower()
        if kind in ["varchar", "binary"]:
            definitions.append(
                f" {self.default_tagname_prefix}{self.default_tag_index_start_num} {tag_type}({self.default_varchar_length}),"
            )
        elif kind == "nchar":
            definitions.append(
                f" {self.default_tagname_prefix}{self.default_tag_index_start_num} {tag_type}({self.default_nchar_length}),"
            )
        else:
            definitions.append(
                f" {self.default_tagname_prefix}{self.default_tag_index_start_num} {tag_type},"
            )
        self.default_tag_index_start_num += 1
    tag_str = "".join(definitions)
    if self.need_tagts:
        tag_str = self.default_tagts_name + " timestamp," + tag_str
    # Drop the trailing comma and the leading space.
    return tag_str[:-1].lstrip()
def gen_default_column_str(self):
    """Build the default column definition list, starting with the timestamp column.

    One column per ``full_type_list`` entry follows, named
    ``<default_colname_prefix><n>`` with ``n`` restarting at 1 on every
    call; varchar/binary and nchar get their default length in parentheses.
    """
    self.default_column_index_start_num = 1
    definitions = []
    for col_type in self.full_type_list:
        kind = col_type.lower()
        if kind in ["varchar", "binary"]:
            definitions.append(
                f" {self.default_colname_prefix}{self.default_column_index_start_num} {col_type}({self.default_varchar_length}),"
            )
        elif kind == "nchar":
            definitions.append(
                f" {self.default_colname_prefix}{self.default_column_index_start_num} {col_type}({self.default_nchar_length}),"
            )
        else:
            definitions.append(
                f" {self.default_colname_prefix}{self.default_column_index_start_num} {col_type},"
            )
        self.default_column_index_start_num += 1
    column_str = self.default_colts_name + " timestamp," + "".join(definitions)
    # Drop the trailing comma.
    return column_str[:-1].lstrip()
def gen_tag_type_str(self, tagname_prefix, tag_elm_list):
    """Render tag definitions from a list of element dicts.

    Each element has a ``type``, an optional ``count`` (default 1; values
    <= 0 skip the element) and, for varchar/binary/nchar, a ``len``.
    Tags are numbered from 1 across all elements.  With ``tag_elm_list``
    None the default tag string is returned instead.
    """
    if tag_elm_list is None:
        return self.gen_default_tag_str()

    definitions = []
    for tag_elm in tag_elm_list:
        repeat = int(tag_elm["count"]) if "count" in tag_elm else 1
        # range() of a non-positive count is empty, which skips the element.
        for _ in range(repeat):
            entry = f"{tagname_prefix}{len(definitions) + 1} {tag_elm['type']}"
            if tag_elm["type"] in ["varchar", "binary", "nchar"]:
                entry += f"({tag_elm['len']})"
            definitions.append(entry)
    return ", ".join(definitions)
def gen_column_type_str(self, colname_prefix, column_elm_list):
    """Render column definitions from a list of element dicts.

    The result starts with ``<default_colts_name> timestamp, ``.  Each
    element has a ``type``, an optional ``count`` (default 1; values <= 0
    skip the element) and, for varchar/binary/nchar, a ``len``.  Columns
    are numbered from 1 across all elements.  With ``column_elm_list`` None
    the default column string is returned instead.
    """
    if column_elm_list is None:
        return self.gen_default_column_str()

    definitions = []
    for column_elm in column_elm_list:
        repeat = int(column_elm["count"]) if "count" in column_elm else 1
        # range() of a non-positive count is empty, which skips the element.
        for _ in range(repeat):
            entry = f"{colname_prefix}{len(definitions) + 1} {column_elm['type']}"
            if column_elm["type"] in ["varchar", "binary", "nchar"]:
                entry += f"({column_elm['len']})"
            definitions.append(entry)
    return self.default_colts_name + " timestamp, " + ", ".join(definitions)
def gen_random_type_value(
    self, type_name, binary_length, binary_type, nchar_length, nchar_type
):
    """Return one random value for the given type name (case-insensitive).

    Integer types draw from the matching ``self.Boundary`` range; float and
    double both draw from ``FLOAT_BOUNDARY``; binary/varchar/nchar return a
    random name of the requested length and character mode; bool picks from
    ``BOOL_BOUNDARY``; timestamp returns ``self.genTs()[0]``.  Unknown type
    names yield None.
    """
    kind = type_name.lower()
    integer_ranges = {
        "tinyint": "TINYINT_BOUNDARY",
        "smallint": "SMALLINT_BOUNDARY",
        "int": "INT_BOUNDARY",
        "bigint": "BIGINT_BOUNDARY",
        "tinyint unsigned": "UTINYINT_BOUNDARY",
        "smallint unsigned": "USMALLINT_BOUNDARY",
        "int unsigned": "UINT_BOUNDARY",
        "bigint unsigned": "UBIGINT_BOUNDARY",
    }
    if kind in integer_ranges:
        attr = integer_ranges[kind]
        return random.randint(
            getattr(self.Boundary, attr)[0], getattr(self.Boundary, attr)[1]
        )
    if kind in ("float", "double"):
        return random.uniform(
            self.Boundary.FLOAT_BOUNDARY[0], self.Boundary.FLOAT_BOUNDARY[1]
        )
    if kind == "binary":
        return f"{self.get_long_name(binary_length, binary_type)}"
    if kind == "varchar":
        return self.get_long_name(binary_length, binary_type)
    if kind == "nchar":
        return self.get_long_name(nchar_length, nchar_type)
    if kind == "bool":
        return random.choice(self.Boundary.BOOL_BOUNDARY)
    if kind == "timestamp":
        return self.genTs()[0]
    return None
def gen_tag_value_list(self, tag_elm_list):
    """Generate one random value per tag.

    With ``tag_elm_list`` None, one value is produced for each entry of
    ``full_type_list`` using the default lengths and character modes.
    Otherwise each element dict contributes ``count`` values (default 1;
    values <= 0 contribute none); varchar/binary/nchar elements use their
    own ``len`` for both the binary and the nchar length.
    """
    if tag_elm_list is None:
        return [
            self.gen_random_type_value(
                type_name,
                self.default_varchar_length,
                self.default_varchar_datatype,
                self.default_nchar_length,
                self.default_nchar_datatype,
            )
            for type_name in self.full_type_list
        ]

    values = []
    for tag_elm in tag_elm_list:
        repeat = int(tag_elm["count"]) if "count" in tag_elm else 1
        # range() of a non-positive count is empty, which skips the element.
        for _ in range(repeat):
            if tag_elm["type"] in ["varchar", "binary", "nchar"]:
                generated = self.gen_random_type_value(
                    tag_elm["type"],
                    tag_elm["len"],
                    self.default_varchar_datatype,
                    tag_elm["len"],
                    self.default_nchar_datatype,
                )
            else:
                generated = self.gen_random_type_value(
                    tag_elm["type"], "", "", "", ""
                )
            values.append(generated)
    return values
def gen_column_value_list(self, column_elm_list, ts_value=None):
    """Generate one row of values: the timestamp followed by one value per column.

    :param column_elm_list: list of element dicts (``type``, optional
        ``count`` defaulting to 1, and ``len`` for varchar/binary/nchar), or
        None to generate one value per ``full_type_list`` entry with the
        default lengths and character modes.
    :param ts_value: timestamp for the first slot; when None,
        ``self.genTs()[0]`` is used.
    :return: ``[ts_value, v1, v2, ...]``.
    """
    if ts_value is None:
        ts_value = self.genTs()[0]
    column_value_list = [ts_value]

    if column_elm_list is None:
        # Extend rather than rebind: the original replaced the list here,
        # which silently dropped the timestamp from the default row even
        # though the default column definition starts with a timestamp.
        column_value_list.extend(
            self.gen_random_type_value(
                type_name,
                self.default_varchar_length,
                self.default_varchar_datatype,
                self.default_nchar_length,
                self.default_nchar_datatype,
            )
            for type_name in self.full_type_list
        )
        return column_value_list

    for column_elm in column_elm_list:
        total_count = int(column_elm["count"]) if "count" in column_elm else 1
        # range() of a non-positive count is empty, which skips the element.
        for _ in range(total_count):
            if column_elm["type"] in ["varchar", "binary", "nchar"]:
                column_value_list.append(
                    self.gen_random_type_value(
                        column_elm["type"],
                        column_elm["len"],
                        self.default_varchar_datatype,
                        column_elm["len"],
                        self.default_nchar_datatype,
                    )
                )
            else:
                column_value_list.append(
                    self.gen_random_type_value(column_elm["type"], "", "", "", "")
                )
    return column_value_list
def create_stable(
    self,
    td_sql,
    dbname=None,
    stbname="stb",
    column_elm_list=None,
    tag_elm_list=None,
    count=1,
    default_stbname_prefix="stb",
    **kwargs,
):
    """Create one or several super tables through ``td_sql``.

    With ``count`` <= 1 a single table named ``stbname`` is created and the
    statement is logged.  Otherwise ``count`` tables named
    ``<default_stbname_prefix>1..count`` are created and ``stbname`` is
    ignored.  Columns use the prefix "c", tags the prefix "t".  Every
    keyword argument is appended as ``<name> "<value>"``.
    """
    stb_params = "".join(f'{param} "{value}" ' for param, value in kwargs.items())
    column_type_str = self.gen_column_type_str("c", column_elm_list)
    tag_type_str = self.gen_tag_type_str("t", tag_elm_list)

    if int(count) <= 1:
        create_stable_sql = f"create table {dbname}.{stbname} ({column_type_str}) tags ({tag_type_str}) {stb_params};"
        tdLog.info("create stb sql: %s" % create_stable_sql)
        td_sql.execute(create_stable_sql)
        return

    for offset in range(count):
        stb_index = offset + 1
        td_sql.execute(
            f"create table {dbname}.{default_stbname_prefix}{stb_index} ({column_type_str}) tags ({tag_type_str}) {stb_params};"
        )
def create_ctable(
    self,
    td_sql,
    dbname=None,
    stbname=None,
    tag_elm_list=None,
    count=1,
    default_ctbname_prefix="ctb",
    **kwargs,
):
    """Create one or several child tables of a super table through td_sql.

    All tables created by one call share the same generated tag values.
    Table names are <default_ctbname_prefix><index>, with the index
    starting at 0.

    Args:
        td_sql: object whose execute(sql) runs the statement.
        dbname (str, optional): database of both the child and super table.
        stbname (str, optional): super table name.
        tag_elm_list (list, optional): passed to gen_tag_value_list().
        count (int, optional): number of child tables.
        default_ctbname_prefix (str, optional): child table name prefix.
        **kwargs: extra table options, rendered as `name "value"`.
    """
    ctb_params = "".join(f'{param} "{value}" ' for param, value in kwargs.items())
    tag_value_list = self.gen_tag_value_list(tag_elm_list)
    # String tags are double-quoted, everything else is rendered as-is
    tag_value_str = ", ".join(
        f'"{tag_value}"' if isinstance(tag_value, str) else f"{tag_value}"
        for tag_value in tag_value_list
    )
    using_clause = f"using {dbname}.{stbname} tags ({tag_value_str}) {ctb_params};"
    # Normalise once so a numeric string works for both the test and the loop
    table_total = int(count)
    if table_total <= 1:
        td_sql.execute(
            f"create table {dbname}.{default_ctbname_prefix}0 {using_clause}"
        )
        return
    for index in range(table_total):
        create_ctable_sql = (
            f"create table {dbname}.{default_ctbname_prefix}{index} {using_clause}"
        )
        tdLog.info("create ctb sql: %s" % create_ctable_sql)
        td_sql.execute(create_ctable_sql)
def create_table(
    self, td_sql, dbname=None, tbname="ntb", column_elm_list=None, count=1, **kwargs
):
    """Create one or several normal tables through td_sql.

    Args:
        td_sql: object whose execute(sql) runs the statement.
        dbname (str, optional): database the tables are created in.
        tbname (str, optional): table name used when count <= 1.
        column_elm_list (list, optional): passed to gen_column_type_str().
        count (int, optional): number of tables. With count > 1 the tables
            are named ntb1 .. ntbN and tbname is not used.
        **kwargs: extra table options, rendered as `name "value"`.
    """
    tbname_prefix = "ntb"
    tb_params = "".join(f'{param} "{value}" ' for param, value in kwargs.items())
    # NOTE(review): the table prefix is also handed to gen_column_type_str as
    # its first argument, exactly as before - confirm that is the intended prefix.
    column_type_str = self.gen_column_type_str(tbname_prefix, column_elm_list)
    # Normalise once so a numeric string works for both the test and the loop
    table_total = int(count)
    if table_total <= 1:
        td_sql.execute(
            f"create table {dbname}.{tbname} ({column_type_str}) {tb_params};"
        )
        return
    for index in range(1, table_total + 1):
        td_sql.execute(
            f"create table {dbname}.{tbname_prefix}{index} ({column_type_str}) {tb_params};"
        )
def insert_rows(
    self,
    td_sql,
    dbname=None,
    tbname=None,
    column_ele_list=None,
    start_ts_value=None,
    count=1,
):
    """Insert one or several rows of generated values through td_sql.

    Args:
        td_sql: object whose execute(sql) runs the statement.
        dbname (str, optional): database of the target table.
        tbname (str, optional): target table. When omitted, self.tb_name is
            used as the target.
        column_ele_list (list, optional): passed to gen_column_value_list().
        start_ts_value (optional): timestamp of the first row. Defaults to
            the first item returned by self.genTs().
        count (int, optional): number of rows. With count > 1, row k uses
            the timestamp expression "<start_ts_value>+<k>s".
    """
    if start_ts_value is None:
        start_ts_value = self.genTs()[0]
    # Both branches write to the same target; the explicit arguments win
    # over the instance-level table name.
    if tbname is None:
        target = self.tb_name
    elif dbname is None:
        target = tbname
    else:
        target = f"{dbname}.{tbname}"

    def render(values, quote_first):
        # Strings are double-quoted. In the multi-row case element 0 is a
        # "<ts>+<k>s" expression and must stay unquoted.
        return ", ".join(
            f'"{value}"'
            if isinstance(value, str) and (quote_first or position != 0)
            else f"{value}"
            for position, value in enumerate(values)
        )

    row_total = int(count)
    if row_total <= 1:
        row = self.gen_column_value_list(column_ele_list, start_ts_value)
        td_sql.execute(f"insert into {target} values ({render(row, True)});")
        return
    for num in range(row_total):
        row = self.gen_column_value_list(
            column_ele_list, f"{start_ts_value}+{num}s"
        )
        td_sql.execute(f"insert into {target} values ({render(row, False)});")
def getOneRow(self, location, containElm):
    """Return the rows of the last tdSql query whose field equals containElm.

    Args:
        location (int): index of the field compared in every row.
        containElm: value the field must equal.

    Returns:
        list: matching rows, or None when location is out of range and
        tdLog.exit() returns.
    """
    # NOTE(review): location is range-checked against the row count
    # (tdSql.queryRows) but is then used as a field index inside each row -
    # confirm which of the two is intended.
    if not 0 <= location < tdSql.queryRows:
        # Report the same counter the check above uses
        tdLog.exit(
            f"getOneRow out of range: row_index={location} row_count={tdSql.queryRows}"
        )
        return None
    return [row for row in tdSql.queryResult if row[location] == containElm]
def killProcessor(self, processorName):
    """Force-kill every process with the given name.

    Runs TASKKILL on Windows; elsewhere unsets LD_PRELOAD and runs pkill -9.

    Args:
        processorName (str): process name, without the ".exe" suffix.
    """
    on_windows = platform.system().lower() == "windows"
    if on_windows:
        command = "TASKKILL /F /IM %s.exe" % processorName
    else:
        command = "unset LD_PRELOAD; pkill -9 %s " % processorName
    os.system(command)
def kill_signal_process(self, signal=15, processor_name: str = "taosd"):
    """Send a signal to the processes matching processor_name.

    On Windows the signal argument is not used and the process is force-killed
    with TASKKILL. Elsewhere `sudo pkill -f -<signal>` is run after
    unsetting LD_PRELOAD.

    Args:
        signal (int, optional): signal number for pkill. Defaults to 15.
        processor_name (str, optional): pattern handed to pkill -f, or the
            image name (without ".exe") on Windows. Defaults to "taosd".
    """
    if platform.system().lower() != "windows":
        command = f"unset LD_PRELOAD; sudo pkill -f -{signal} '{processor_name}'"
        tdLog.debug(f"command: {command}")
        os.system(command)
        return
    os.system(f"TASKKILL /F /IM {processor_name}.exe")
def gen_tag_col_str(self, gen_type, data_type, count):
    """Build a comma-separated list of `count` tag or column definitions.

    Names are gen_type followed by an index starting at 0, and every entry
    uses the same data_type, e.g. "c0 int,c1 int".
    """
    definitions = [f"{gen_type}{index} {data_type}" for index in range(count)]
    return ",".join(definitions)
# old_stream
def create_old_stream(
    self,
    stream_name,
    des_table,
    source_sql,
    trigger_mode=None,
    watermark=None,
    max_delay=None,
    ignore_expired=None,
    ignore_update=None,
    subtable_value=None,
    fill_value=None,
    fill_history_value=None,
    stb_field_name_value=None,
    tag_value=None,
    use_exist_stb=False,
    use_except=False,
):
    """Build a `create stream if not exists` statement in the old syntax.

    With use_except=False the statement is executed through tdSql and the
    method sleeps for self.create_stream_sleep seconds; with use_except=True
    the statement is only returned.

    Args:
        stream_name (str): stream_name
        des_table (str): target stable
        source_sql (str): stream sql
        trigger_mode (str, optional): at_once/window_close/max_delay. When
            None, the executed statement uses "trigger at_once" and the
            returned one carries no trigger clause.
        watermark (str, optional): watermark time.
        max_delay (str, optional): delay, used when trigger_mode is "max_delay".
        ignore_expired (int, optional): rendered as 0 when falsy.
        ignore_update (int, optional): rendered as 0 when falsy.
        subtable_value (str, optional): subtable expression.
        fill_value (str, optional): fill expression.
        fill_history_value (int, optional): 0/1.
        stb_field_name_value (str, optional): field list of an existing stb;
            only used with use_exist_stb and a trigger_mode.
        tag_value (str, optional): custom tags; only used with use_exist_stb
            and a trigger_mode.
        use_exist_stb (bool, optional): enable the two arguments above.
        use_except (bool, optional): return the SQL instead of running it.

    Returns:
        str: the statement when use_except is True, otherwise None.
    """
    subtable = "" if subtable_value is None else f"subtable({subtable_value})"
    fill = "" if fill_value is None else f"fill({fill_value})"
    fill_history = (
        "" if fill_history_value is None else f"fill_history {fill_history_value}"
    )
    stb_field_name = ""
    tags = ""
    if use_exist_stb:
        if stb_field_name_value is not None:
            stb_field_name = f"({stb_field_name_value})"
        if tag_value is not None:
            tags = f"tags({tag_value})"

    # Both flags are always emitted; a falsy value is written as 0
    expired_flag = ignore_expired if ignore_expired else 0
    update_flag = ignore_update if ignore_update else 0
    ignore_clause = f" ignore expired {expired_flag} ignore update {update_flag}"

    if trigger_mode is None:
        stream_options = (
            "" if watermark is None else f"watermark {watermark}"
        ) + ignore_clause
        head = f"create stream if not exists {stream_name}"
        tail = f"{stream_options} {fill_history} into {des_table} {subtable} as {source_sql} {fill};"
        if use_except:
            return f"{head} {tail}"
        statement = f"{head} trigger at_once {tail}"
    else:
        stream_options = f"trigger {trigger_mode}"
        if trigger_mode == "max_delay":
            stream_options += f" {max_delay}"
        if watermark is not None:
            stream_options += f" watermark {watermark}"
        stream_options += ignore_clause
        statement = f"create stream if not exists {stream_name} {stream_options} {fill_history} into {des_table}{stb_field_name} {tags} {subtable} as {source_sql} {fill};"
        if use_except:
            return statement

    tdSql.execute(statement, queryTimes=3)
    time.sleep(self.create_stream_sleep)
    return None
def create_stream(
    self,
    stream_name,
    des_table=None,
    source_sql=None,
    trigger_table=None,
    trigger_type=None,
    from_table=None,
    partition_by=None,
    stream_options=None,
    notification_definition=None,
    output_subtable=None,
    columns=None,
    tags=None,
    if_not_exists=True,
    db_name=None,
    use_except=False,
):
    """Build a CREATE STREAM statement in the new syntax.

    The statement is printed, then either returned (use_except=True) or
    executed through tdSql followed by a sleep of self.create_stream_sleep
    seconds. Clauses whose argument is empty are left out.

    Args:
        stream_name (str): stream_name
        des_table (str, optional): target table for the INTO clause.
        source_sql (str, optional): subquery for the AS clause.
        trigger_table (str, optional): trigger table, emitted as "from <table>".
        trigger_type (str, optional): SESSION/STATE_WINDOW/INTERVAL/EVENT_WINDOW/COUNT_WINDOW/PERIOD.
        from_table (str, optional): source table for the FROM clause.
        partition_by (str, optional): partition columns.
        stream_options (str, optional): content of STREAM_OPTIONS(...).
        notification_definition (str, optional): emitted verbatim.
        output_subtable (str, optional): content of OUTPUT_SUBTABLE(...).
        columns (str, optional): column definitions, emitted in parentheses.
        tags (str, optional): content of TAGS (...).
        if_not_exists (bool, optional): add IF NOT EXISTS. Defaults to True.
        db_name (str, optional): prefix for the stream name.
        use_except (bool, optional): return the SQL instead of running it.

    Returns:
        str: stream SQL if use_except=True, None otherwise
    """
    option_items = (
        f"FROM {from_table}" if from_table else "",
        f"PARTITION BY {partition_by}" if partition_by else "",
        f"STREAM_OPTIONS({stream_options})" if stream_options else "",
        notification_definition if notification_definition else "",
    )
    options_clause = " ".join(item for item in option_items if item)

    # Clause order is fixed; empty entries are dropped before joining
    ordered_clauses = (
        "CREATE STREAM",
        "IF NOT EXISTS" if if_not_exists else "",
        f"{db_name}.{stream_name}" if db_name else stream_name,
        f" {trigger_type} " if trigger_type else "",
        f" from {trigger_table} " if trigger_table else "",
        options_clause,
        f"INTO {des_table}" if des_table else "",
        f"OUTPUT_SUBTABLE({output_subtable})" if output_subtable else "",
        f"({columns})" if columns else "",
        f"TAGS ({tags})" if tags else "",
        f"AS {source_sql}" if source_sql else "",
    )
    create_stream_sql = " ".join(clause for clause in ordered_clauses if clause) + ";"

    print(f"create stream sql: {create_stream_sql}")
    if use_except:
        return create_stream_sql
    tdSql.execute(create_stream_sql, queryTimes=3)
    time.sleep(self.create_stream_sleep)
    return None
def pause_stream(self, stream_name, if_exist=True, if_not_exist=False):
    """Pause a stream through tdSql.

    Args:
        stream_name (str): stream_name
        if_exist (bool, optional): add "if exists". Defaults to True.
        if_not_exist (bool, optional): add "if not exists". Defaults to False.
    """
    exists_clause = "if exists" if if_exist else ""
    not_exists_clause = "if not exists" if if_not_exist else ""
    statement = f"pause stream {exists_clause} {not_exists_clause} {stream_name}"
    tdSql.execute(statement)
def resume_stream(
    self, stream_name, if_exist=True, if_not_exist=False, ignore_untreated=False
):
    """Resume a stream through tdSql.

    Args:
        stream_name (str): stream_name
        if_exist (bool, optional): add "if exists". Defaults to True.
        if_not_exist (bool, optional): add "if not exists". Defaults to False.
        ignore_untreated (bool, optional): add "ignore untreated". Defaults to False.
    """
    exists_clause = "if exists" if if_exist else ""
    not_exists_clause = "if not exists" if if_not_exist else ""
    untreated_clause = "ignore untreated" if ignore_untreated else ""
    statement = f"resume stream {exists_clause} {not_exists_clause} {untreated_clause} {stream_name}"
    tdSql.execute(statement)
def drop_all_streams(self):
    """Drop every stream of every database except the two schema databases.

    Databases come from "show databases"; information_schema and
    performance_schema are skipped. Any error raised while handling one
    database is logged at debug level and the next database is processed.
    """
    system_dbs = ("information_schema", "performance_schema")
    tdSql.query("show databases")
    # Snapshot the names first: tdSql.queryResult is replaced by later queries
    user_db_list = [row[0] for row in tdSql.queryResult if row[0] not in system_dbs]
    for db_name in user_db_list:
        try:
            tdSql.query(f"show {db_name}.streams")
            stream_name_list = [row[0] for row in tdSql.queryResult]
            for stream_name in stream_name_list:
                # A name containing a dot is taken as already qualified
                if "." not in stream_name:
                    tdSql.execute(f"drop stream if exists {db_name}.{stream_name};")
                else:
                    tdSql.execute(f"drop stream if exists {stream_name};")
                tdLog.debug(
                    f"Dropped stream: {stream_name} from database: {db_name}"
                )
        except Exception as e:
            tdLog.debug(
                f"No streams found in database {db_name} or database doesn't exist: {e}"
            )
            continue
    tdLog.info("All user database streams have been dropped")
def check_stream_wal_info(self, wal_info):
    """Tell whether the position in a stream task 'info' text is close to its upper bound.

    The text is searched for "<n> [<a>, <b>]". The result is True when
    b - 5 <= n <= b for b >= 5, or b - 1 <= n <= b for b < 5. It is False
    otherwise, including when the pattern is not found.

    Args:
        wal_info (str): value of the 'info' column of
            information_schema.ins_stream_tasks.

    Returns:
        bool: True when n lies inside the accepted window below b.
    """
    found = re.search(r"(\d+)\s*\[(\d+),\s*(\d+)\]", wal_info)
    if not found:
        return False
    current = int(found.group(1))  # number in front of the brackets
    upper = int(found.group(3))  # second number inside the brackets
    allowed_lag = 5 if upper >= 5 else 1
    return upper - allowed_lag <= current <= upper
def check_stream_task_status(
    self, stream_name, vgroups, stream_timeout=0, check_wal_info=True
):
    """Poll information_schema.ins_stream_tasks until the stream's tasks are Running.

    First asserts, through tdSql.checkRows, that the stream has `vgroups`
    task rows. It then retries up to `timeout` times, sleeping one second
    between attempts, until the distinct task status is exactly "Running"
    and, when check_wal_info is set, check_stream_wal_info() accepts every
    task row.

    Args:
        stream_name (str): stream_name
        vgroups (int): expected number of task rows
        stream_timeout (int, optional): maximum number of attempts.
            self.stream_timeout is used only when None is passed. Defaults to 0.
        check_wal_info (bool, optional): also validate each task row with
            check_stream_wal_info(). Defaults to True.

    Returns:
        None: the method has no return statement; failures are reported
        through tdSql.print_error_frame_info.
    """
    # NOTE(review): the default is 0, not None, so with the default argument
    # timeout is 0 and the while body below never runs - confirm intended default.
    timeout = self.stream_timeout if stream_timeout is None else stream_timeout
    # check stream task rows
    sql_task_all = f"select `task_id`,node_id,stream_name,status from information_schema.ins_stream_tasks where stream_name='{stream_name}';"
    sql_task_status = f"select distinct(status) from information_schema.ins_stream_tasks where stream_name='{stream_name}'"
    tdSql.query(sql_task_all)
    tdSql.checkRows(vgroups)
    # check stream task status
    checktimes = 1
    check_stream_success = 0
    vgroup_num = 0
    while checktimes <= timeout:
        tdLog.notice(f"checktimes:{checktimes}")
        try:
            # Each statement is run twice: once with row_tag=True and once without.
            # presumably row_tag=True returns the rows and the plain call the
            # row count - verify against tdSql.query
            result_task_all = tdSql.query(sql_task_all, row_tag=True)
            result_task_all_rows = tdSql.query(sql_task_all)
            result_task_status = tdSql.query(sql_task_status, row_tag=True)
            result_task_status_rows = tdSql.query(sql_task_status)
            tdLog.notice(
                f"Try to check stream status, check times: {checktimes} and stream task list[{check_stream_success}]"
            )
            print(
                f"result_task_status:{result_task_status},result_task_all:{result_task_all}"
            )
            if result_task_status_rows == 1 and result_task_status == [
                ("Running",)
            ]:
                if check_wal_info:
                    for vgroup_num in range(vgroups):
                        # NOTE(review): index 4 is read here, but sql_task_all
                        # selects only four columns - TODO confirm the column list.
                        if self.check_stream_wal_info(
                            result_task_all[vgroup_num][4]
                        ):
                            check_stream_success += 1
                            tdLog.info(
                                f"check stream task list[{check_stream_success}] sucessfully :"
                            )
                        else:
                            # one rejected row restarts the count for this attempt
                            check_stream_success = 0
                            break
                else:
                    check_stream_success = vgroups
            if check_stream_success == vgroups:
                # leaving through break skips the while-else failure branch
                break
            time.sleep(1)
            checktimes += 1
            vgroup_num = vgroup_num
        except Exception as e:
            tdLog.notice(
                f"Try to check stream status again, check times: {checktimes}"
            )
            checktimes += 1
            # NOTE(review): result_task_all is unbound here if the very first
            # query of the first attempt raised.
            tdSql.print_error_frame_info(
                result_task_all[vgroup_num],
                "status is ready,info is finished and history_task_id is NULL",
                sql_task_all,
            )
    else:
        # Runs only when the loop ended without break, i.e. the attempts ran out
        checktimes_end = checktimes - 1
        tdLog.notice(
            f"it has spend {checktimes_end} for checking stream task status but it failed"
        )
        if checktimes_end == timeout:
            # NOTE(review): when the loop body never ran (timeout <= 0),
            # result_task_all has not been assigned at this point.
            tdSql.print_error_frame_info(
                result_task_all[vgroup_num],
                "status is ready,info is finished and history_task_id is NULL",
                sql_task_all,
            )
def drop_db(self, dbname="test"):
    """Drop a database if it exists.

    A name that starts with a digit is wrapped in backticks.

    Args:
        dbname (str, optional): Defaults to "test".
    """
    target = f"`{dbname}`" if dbname[0].isdigit() else dbname
    tdSql.execute(f"drop database if exists {target}")
def drop_all_db(self):
    """Drop every database except those in self.white_list or named *telegraf*."""
    tdSql.query("show databases;")
    # Snapshot the names first: tdSql.queryResult may change once drops run
    db_names = [row[0] for row in tdSql.queryResult]
    for dbname in db_names:
        if dbname in self.white_list or "telegraf" in dbname:
            continue
        tdSql.execute(f"drop database if exists `{dbname}`")
def time_cast(self, time_value, split_symbol="+"):
    """Wrap a timestamp expression into a SQL cast plus offset.

    "<ts><symbol><offset>" becomes "cast(<ts> as timestamp)<symbol><offset>".
    Without the symbol in the input, the offset is "0s".

    Args:
        time_value (bigint): ts, optionally followed by the symbol and an offset
        split_symbol (str, optional): separator between ts and offset. Defaults to "+".

    Returns:
        str: the cast expression
    """
    pieces = str(time_value).split(split_symbol)
    base_ts = pieces[0]
    # Only the piece right after the first separator is used as the offset
    offset = pieces[1] if len(pieces) > 1 else "0s"
    return f"cast({base_ts} as timestamp){split_symbol}{offset}"
def clean_env(self):
    """Reset the environment: drop all streams first, then all databases."""
    # Streams are removed before the databases are dropped
    self.drop_all_streams()
    self.drop_all_db()
def set_precision_offset(self, precision):
    """Set self.offset to the number of ticks per second for a precision.

    "ms" gives 1000, "us" 1000000 and "ns" 1000000000. Any other value
    leaves self.offset untouched.
    """
    ticks_per_second = (("ms", 1000), ("us", 1000000), ("ns", 1000000000))
    for unit, ticks in ticks_per_second:
        if precision == unit:
            self.offset = ticks
            break
def genTs(self, precision="ms", ts="", protype="taosc", ns_tag=None):
    """Generate a timestamp and its formatted form.

    Args:
        precision (str, optional): db precision, "ms", "us" or "ns".
            None is treated as "ms". Defaults to "ms".
        ts (optional): input ts. For "ns" it is taken as integer
            nanoseconds; for "ms"/"us" as seconds, which are scaled to the
            precision. Empty or None means the current time. Defaults to "".
        protype (str, optional): "taosc" or "restful". Defaults to "taosc".
        ns_tag (optional): for "ns", any non-None value requests the
            formatted string instead of the raw number. Defaults to None.

    Returns:
        tuple: (ts, dt). ts is the integer timestamp in the requested
        precision. dt is a local-time "YYYY-mm-dd HH:MM:SS.<fraction>"
        string, with these exceptions: for "ns" without ns_tag and with a
        protype other than "restful", dt is ts itself; for "ms" with a
        protype other than "taosc"/"restful", dt is a datetime object.

    Raises:
        ValueError: if precision is not "ms", "us", "ns" or None.
    """

    def _stamp(epoch_seconds, fraction, width):
        # Local-time text with the sub-second part zero-padded to `width`
        base = datetime.fromtimestamp(epoch_seconds).strftime("%Y-%m-%d %H:%M:%S")
        return base + "." + str(int(fraction)).zfill(width)

    if precision not in ("ns", "ms", "us", None):
        # Fail with a clear message instead of reaching the return unassigned
        raise ValueError(f"unsupported precision: {precision!r}")

    use_current_time = ts == "" or ts is None
    if precision == "ns":
        if use_current_time:
            ts = time.time_ns()
        if ns_tag is None and protype != "restful":
            return ts, ts
        return ts, _stamp(ts // 1000000000, ts % 1000000000, 9)

    if use_current_time:
        ts = time.time()
    if precision == "us":
        ts = int(round(ts * 1000000))
        return ts, _stamp(ts // 1000000, ts % 1000000, 6)

    # "ms" or None
    ts = int(round(ts * 1000))
    if protype == "taosc":
        # taosc form carries six fractional digits: milliseconds plus "000"
        return ts, _stamp(ts // 1000, ts % 1000, 3) + "000"
    if protype == "restful":
        return ts, _stamp(ts // 1000, ts % 1000, 3)
    return ts, datetime.fromtimestamp(ts // 1000)
def sgen_column_type_str(self, column_elm_list):
    """Generate the column definition string into self.column_type_str.

    With column_elm_list None the result of gen_default_column_str() is
    stored. Otherwise each descriptor adds "count" (default 1) columns named
    <default_colname_prefix><index>, the index being taken from and advanced
    in self.default_column_index_start_num, and the timestamp column
    self.default_colts_name is put in front.

    Args:
        column_elm_list (list): dicts with "type", optional "count" and,
            for varchar/binary/nchar, "len".
    """
    self.column_type_str = ""
    if column_elm_list is None:
        self.column_type_str = self.gen_default_column_str()
        return
    sized_types = ("varchar", "binary", "nchar")
    for column_elm in column_elm_list:
        repeat = int(column_elm["count"]) if "count" in column_elm else 1
        # a count <= 0 yields an empty range and adds no column
        for _ in range(repeat):
            definition = f"{self.default_colname_prefix}{self.default_column_index_start_num} {column_elm['type']}"
            if column_elm["type"] in sized_types:
                definition += f"({column_elm['len']})"
            self.column_type_str += definition + ", "
            self.default_column_index_start_num += 1
    # drop the trailing ", " and lead with the timestamp column
    body = self.column_type_str.rstrip()[:-1]
    self.column_type_str = self.default_colts_name + " timestamp, " + body
def sgen_tag_type_str(self, tag_elm_list):
    """Generate the tag definition string into self.tag_type_str.

    With tag_elm_list None the result of gen_default_tag_str() is stored.
    Otherwise each descriptor adds "count" (default 1) tags named
    <default_tagname_prefix><index>, the index being taken from and advanced
    in self.default_tag_index_start_num. When self.need_tagts is set, a
    timestamp tag named self.default_tagts_name is put in front.

    Args:
        tag_elm_list (list): dicts with "type", optional "count" and, for
            varchar/binary/nchar, "len".
    """
    self.tag_type_str = ""
    if tag_elm_list is None:
        self.tag_type_str = self.gen_default_tag_str()
        return
    sized_types = ("varchar", "binary", "nchar")
    for tag_elm in tag_elm_list:
        repeat = int(tag_elm["count"]) if "count" in tag_elm else 1
        # a count <= 0 yields an empty range and adds no tag
        for _ in range(repeat):
            definition = f"{self.default_tagname_prefix}{self.default_tag_index_start_num} {tag_elm['type']}"
            if tag_elm["type"] in sized_types:
                definition += f"({tag_elm['len']})"
            self.tag_type_str += definition + ", "
            self.default_tag_index_start_num += 1
    # drop the trailing ", "
    self.tag_type_str = self.tag_type_str.rstrip()[:-1]
    if self.need_tagts:
        self.tag_type_str = (
            self.default_tagts_name + " timestamp, " + self.tag_type_str
        )
def sgen_tag_value_list(self, tag_elm_list, ts_value=None):
    """Generate random tag values into self.tag_value_list.

    The list is rebuilt from scratch on every call. When self.need_tagts is
    set, the timestamp (ts_value if given, otherwise the first item of
    self.genTs()) is placed in front of the generated values.

    Args:
        tag_elm_list (list): dicts with "type", optional "count" (default 1)
            and, for varchar/binary/nchar, "len". None means one value for
            every entry of self.full_type_list with the default settings.
        ts_value (timestamp, optional): Defaults to None.
    """
    if self.need_tagts:
        self.ts_value = self.genTs()[0]
    if ts_value is not None:
        self.ts_value = ts_value
    # Start empty on every call so values of earlier calls do not pile up
    self.tag_value_list = []
    if tag_elm_list is None:
        self.tag_value_list = [
            self.gen_random_type_value(
                type_name,
                self.default_varchar_length,
                self.default_varchar_datatype,
                self.default_nchar_length,
                self.default_nchar_datatype,
            )
            for type_name in self.full_type_list
        ]
    else:
        for tag_elm in tag_elm_list:
            total_count = int(tag_elm["count"]) if "count" in tag_elm else 1
            # range() is empty for counts <= 0, so such descriptors add nothing
            for _ in range(total_count):
                if tag_elm["type"] in ("varchar", "binary", "nchar"):
                    # the declared length is used for both the varchar and nchar slots
                    value = self.gen_random_type_value(
                        tag_elm["type"],
                        tag_elm["len"],
                        self.default_varchar_datatype,
                        tag_elm["len"],
                        self.default_nchar_datatype,
                    )
                else:
                    value = self.gen_random_type_value(
                        tag_elm["type"], "", "", "", ""
                    )
                self.tag_value_list.append(value)
    if self.need_tagts:
        self.tag_value_list = [self.ts_value] + self.tag_value_list
def screateDb(self, dbname="test", drop_db=True, **kwargs):
    """Create a database and switch to it.

    Args:
        dbname (str, optional): Defaults to "test".
        drop_db (bool, optional): drop the database first. Defaults to True.
        **kwargs: database options appended to the statement; the value of
            "precision" is double-quoted, other values are written as-is.
    """
    tdLog.info("creating db ...")
    db_params = "".join(
        f'{param} "{value}" ' if param == "precision" else f"{param} {value} "
        for param, value in kwargs.items()
    )
    if drop_db:
        self.drop_db(dbname)
    for statement in (
        f"create database if not exists {dbname} {db_params}",
        f"use {dbname}",
    ):
        tdSql.execute(statement)
def screate_stable(
    self,
    dbname=None,
    stbname="stb",
    use_name="table",
    column_elm_list=None,
    tag_elm_list=None,
    need_tagts=False,
    count=1,
    default_stbname_prefix="stb",
    default_stbname_index_start_num=1,
    default_column_index_start_num=1,
    default_tag_index_start_num=1,
    **kwargs,
):
    """Create one or several super tables through tdSql.

    The arguments are also stored on the instance (self.dbname when given,
    self.need_tagts and the default_* settings) before the column and tag
    definitions are generated.

    Args:
        dbname (str, optional): Defaults to None, which keeps self.dbname.
        stbname (str, optional): table name used when count <= 1. Defaults to "stb".
        use_name (str, optional): stable/table, Defaults to "table".
        column_elm_list (list, optional): use for sgen_column_type_str(), Defaults to None.
        tag_elm_list (list, optional): use for sgen_tag_type_str(), Defaults to None.
        need_tagts (bool, optional): tag use timestamp, Defaults to False.
        count (int, optional): stable count, Defaults to 1. With count > 1
            the tables are named <default_stbname_prefix><index>.
        default_stbname_prefix (str, optional): Defaults to "stb".
        default_stbname_index_start_num (int, optional): Defaults to 1.
        default_column_index_start_num (int, optional): Defaults to 1.
        default_tag_index_start_num (int, optional): Defaults to 1.
        **kwargs: extra table options, rendered as `name "value"`.
    """
    tdLog.info("creating stable ...")
    if dbname is not None:
        self.dbname = dbname
    self.need_tagts = need_tagts
    self.default_stbname_prefix = default_stbname_prefix
    self.default_stbname_index_start_num = default_stbname_index_start_num
    self.default_column_index_start_num = default_column_index_start_num
    self.default_tag_index_start_num = default_tag_index_start_num
    stb_params = "".join(f'{param} "{value}" ' for param, value in kwargs.items())
    self.sgen_column_type_str(column_elm_list)
    self.sgen_tag_type_str(tag_elm_list)
    # The same qualification rule applies to single and multiple tables, so
    # a missing database never shows up as a literal "None." prefix.
    db_prefix = f"{self.dbname}." if self.dbname is not None else ""
    schema = f"({self.column_type_str}) tags ({self.tag_type_str}) {stb_params};"
    table_total = int(count)
    if table_total <= 1:
        tdSql.execute(f"create {use_name} {db_prefix}{stbname} {schema}")
        return
    for offset in range(table_total):
        index = default_stbname_index_start_num + offset
        tdSql.execute(
            f"create {use_name} {db_prefix}{default_stbname_prefix}{index} {schema}"
        )
def screate_ctable(
    self,
    dbname=None,
    stbname=None,
    ctbname="ctb",
    use_name="table",
    tag_elm_list=None,
    ts_value=None,
    count=1,
    default_varchar_datatype="letters",
    default_nchar_datatype="letters",
    default_ctbname_prefix="ctb",
    default_ctbname_index_start_num=1,
    **kwargs,
):
    """Create one or several child tables through tdSql.

    All tables created by one call share the same generated tag values.

    Args:
        dbname (str, optional): Defaults to None. When given it is stored in
            self.dbname and qualifies ctbname.
        stbname (str, optional): super table named in the "using" clause.
            Defaults to None, which falls back to self.stb_name.
        ctbname (str, optional): table name used when count <= 1. Defaults to "ctb".
        use_name (str, optional): Defaults to "table".
        tag_elm_list (list, optional): use for sgen_tag_value_list(), Defaults to None.
        ts_value (timestamp, optional): Defaults to None.
        count (int, optional): ctb count, Defaults to 1. With count > 1 the
            tables are named <default_ctbname_prefix><index>.
        default_varchar_datatype (str, optional): Defaults to "letters".
        default_nchar_datatype (str, optional): Defaults to "letters".
        default_ctbname_prefix (str, optional): Defaults to "ctb".
        default_ctbname_index_start_num (int, optional): Defaults to 1.
        **kwargs: extra table options, rendered as `name "value"`.
    """
    tdLog.info("creating childtable ...")
    self.default_varchar_datatype = default_varchar_datatype
    self.default_nchar_datatype = default_nchar_datatype
    self.default_ctbname_prefix = default_ctbname_prefix
    self.default_ctbname_index_start_num = default_ctbname_index_start_num
    ctb_params = "".join(f'{param} "{value}" ' for param, value in kwargs.items())
    self.sgen_tag_value_list(tag_elm_list, ts_value)
    # String tags are double-quoted, everything else is rendered as-is
    tag_value_str = ", ".join(
        f'"{tag_value}"' if isinstance(tag_value, str) else f"{tag_value}"
        for tag_value in self.tag_value_list
    )
    if dbname is not None:
        self.dbname = dbname
        ctb_name = f"{self.dbname}.{ctbname}"
    else:
        ctb_name = ctbname
    # Both branches use the same super table: the explicit argument when
    # given, otherwise the instance-level name.
    stb_name = stbname if stbname is not None else self.stb_name
    using_clause = f"using {stb_name} tags ({tag_value_str}) {ctb_params};"
    table_total = int(count)
    if table_total <= 1:
        tdSql.execute(f"create {use_name} {ctb_name} {using_clause}")
        return
    db_prefix = f"{self.dbname}." if self.dbname is not None else ""
    for offset in range(table_total):
        index = default_ctbname_index_start_num + offset
        tdSql.execute(
            f"create {use_name} {db_prefix}{default_ctbname_prefix}{index} {using_clause}"
        )
def sgen_column_value_list(
    self,
    column_elm_list,
    need_null,
    ts_value=None,
    additional_ts=None,
    custom_col_index=None,
    col_value_type=None,
    force_pk_val=None,
):
    """Generate one row of random column values into self.column_value_list.

    The final list is [ts, (additional_ts,) value_1, value_2, ...].

    Args:
        column_elm_list (list): column descriptors handed to
            gen_random_type_value(); dicts with "type", optional "count"
            (default 1) and, for varchar/binary/nchar, "len". None means one
            value for every entry of self.full_type_list.
        need_null (bool): if insert null
        ts_value (timestamp, optional): Defaults to None, which uses the
            first item of self.genTs().
        additional_ts (optional): when not None, self.additional_ts is
            inserted right after the timestamp. Defaults to None.
        custom_col_index (int, optional): index of the column that receives
            a controlled value. Defaults to None.
        col_value_type (str, optional): "Random", "Incremental" or
            "Part_equal"; other values leave the column untouched.
        force_pk_val (optional): value that replaces the generated one when
            custom_col_index is 1. Defaults to None.
    """
    self.column_value_list = list()
    self.ts_value = self.genTs()[0]
    if additional_ts is not None:
        # NOTE(review): genTs as defined in this file takes no additional_ts
        # keyword and returns a 2-tuple, so this call looks broken - confirm.
        self.additional_ts = self.genTs(additional_ts=additional_ts)[2]
    if ts_value is not None:
        self.ts_value = ts_value
    if column_elm_list is None:
        self.column_value_list = list(
            map(
                lambda i: self.gen_random_type_value(
                    i,
                    self.default_varchar_length,
                    self.default_varchar_datatype,
                    self.default_nchar_length,
                    self.default_nchar_datatype,
                ),
                self.full_type_list,
            )
        )
    else:
        for column_elm in column_elm_list:
            if "count" in column_elm:
                total_count = int(column_elm["count"])
            else:
                total_count = 1
            if total_count > 0:
                for _ in range(total_count):
                    if column_elm["type"] in ["varchar", "binary", "nchar"]:
                        # the declared length is used for both length arguments
                        self.column_value_list.append(
                            self.gen_random_type_value(
                                column_elm["type"],
                                column_elm["len"],
                                self.default_varchar_datatype,
                                column_elm["len"],
                                self.default_nchar_datatype,
                            )
                        )
                    else:
                        self.column_value_list.append(
                            self.gen_random_type_value(
                                column_elm["type"], "", "", "", ""
                            )
                        )
            else:
                continue
    if need_null:
        # Overwrite len/2 randomly drawn positions with None; the same
        # position may be drawn more than once.
        for i in range(int(len(self.column_value_list) / 2)):
            index_num = random.randint(0, len(self.column_value_list) - 1)
            self.column_value_list[index_num] = None
    if custom_col_index is not None:
        # At this point the list does not yet contain the timestamp, so
        # custom_col_index addresses the generated values only.
        if col_value_type == "Random":
            pass
        elif col_value_type == "Incremental":
            self.column_value_list[custom_col_index] = self.custom_col_val
            self.custom_col_val += 1
        elif col_value_type == "Part_equal":
            self.column_value_list[custom_col_index] = random.choice(
                self.part_val_list
            )
    self.column_value_list = (
        [self.ts_value] + [self.additional_ts] + self.column_value_list
        if additional_ts is not None
        else [self.ts_value] + self.column_value_list
    )
    # From here on index 1 is the element right after the timestamp.
    # NOTE(review): for "Incremental" self.custom_col_val was already
    # advanced above, so this writes the next value - confirm intent.
    if col_value_type == "Incremental" and custom_col_index == 1:
        self.column_value_list[custom_col_index] = (
            self.custom_col_val if force_pk_val is None else force_pk_val
        )
    if col_value_type == "Part_equal" and custom_col_index == 1:
        self.column_value_list[custom_col_index] = (
            random.randint(0, self.custom_col_val)
            if force_pk_val is None
            else force_pk_val
        )
def screate_table(
    self,
    dbname=None,
    tbname="tb",
    use_name="table",
    column_elm_list=None,
    count=1,
    default_tbname_prefix="tb",
    default_tbname_index_start_num=1,
    default_column_index_start_num=1,
    **kwargs,
):
    """Create one or several normal tables through tdSql.

    The arguments are also stored on the instance (self.dbname when given
    and the default_* settings) before the column definitions are generated.

    Args:
        dbname (str, optional): Defaults to None, which keeps self.dbname.
        tbname (str, optional): table name used when count <= 1. Defaults to "tb".
        use_name (str, optional): Defaults to "table".
        column_elm_list (list, optional): use for sgen_column_type_str(), Defaults to None.
        count (int, optional): Defaults to 1. With count > 1 the tables are
            named <default_tbname_prefix><index>.
        default_tbname_prefix (str, optional): Defaults to "tb".
        default_tbname_index_start_num (int, optional): Defaults to 1.
        default_column_index_start_num (int, optional): Defaults to 1.
        **kwargs: extra table options, rendered as `name "value"`.
    """
    tdLog.info("creating table ...")
    if dbname is not None:
        self.dbname = dbname
    self.default_tbname_prefix = default_tbname_prefix
    self.default_tbname_index_start_num = default_tbname_index_start_num
    self.default_column_index_start_num = default_column_index_start_num
    tb_params = "".join(f'{param} "{value}" ' for param, value in kwargs.items())
    self.sgen_column_type_str(column_elm_list)
    # The same qualification rule applies to single and multiple tables, so
    # a missing database never shows up as a literal "None." prefix.
    db_prefix = f"{self.dbname}." if self.dbname is not None else ""
    schema = f"({self.column_type_str}) {tb_params};"
    table_total = int(count)
    if table_total <= 1:
        tdSql.execute(f"create {use_name} {db_prefix}{tbname} {schema}")
        return
    for offset in range(table_total):
        index = default_tbname_index_start_num + offset
        tdSql.execute(
            f"create {use_name} {db_prefix}{default_tbname_prefix}{index} {schema}"
        )
def sinsert_rows(
    self,
    dbname=None,
    tbname=None,
    column_ele_list=None,
    ts_value=None,
    count=1,
    need_null=False,
    custom_col_index=None,
    col_value_type="random",
):
    """Insert one or several generated rows into self.tbname through tdSql.

    Args:
        dbname (str, optional): Defaults to None. When given it is stored in
            self.dbname and qualifies tbname.
        tbname (str, optional): Defaults to None, which keeps self.tbname.
        column_ele_list (list, optional): Defaults to None.
        ts_value (timestamp, optional): timestamp of the single row.
            Defaults to None. With count > 1 every row uses a fresh
            self.genTs() value plus "+<k>s" instead.
        count (int, optional): Defaults to 1.
        need_null (bool, optional): Defaults to False.
        custom_col_index (int, optional): forwarded to sgen_column_value_list().
        col_value_type (str, optional): forwarded to sgen_column_value_list().
    """
    tdLog.info("stream inserting ...")
    if dbname is not None:
        self.dbname = dbname
        if tbname is not None:
            self.tbname = f"{self.dbname}.{tbname}"
    elif tbname is not None:
        self.tbname = tbname

    def render_row(values, bare_markers):
        # None becomes Null; strings are double-quoted unless they contain
        # one of bare_markers, in which case they are written as-is.
        rendered = []
        for value in values:
            if value is None:
                rendered.append("Null")
            elif isinstance(value, str) and not any(
                marker in value for marker in bare_markers
            ):
                rendered.append(f'"{value}"')
            else:
                rendered.append(f"{value}")
        return ", ".join(rendered)

    # A first row is always generated, also when count > 1
    self.sgen_column_value_list(
        column_ele_list,
        need_null,
        ts_value,
        custom_col_index=custom_col_index,
        col_value_type=col_value_type,
    )
    column_value_str = render_row(self.column_value_list, ("+", "-"))
    if int(count) <= 1:
        tdSql.execute(f"insert into {self.tbname} values ({column_value_str});")
        return
    for num in range(count):
        ts_value = self.genTs()[0]
        self.sgen_column_value_list(
            column_ele_list,
            need_null,
            f"{ts_value}+{num}s",
            custom_col_index=custom_col_index,
            col_value_type=col_value_type,
        )
        # in this branch only "+" keeps a string unquoted
        column_value_str = render_row(self.column_value_list, ("+",))
        tdSql.execute(f"insert into {self.tbname} values ({column_value_str});")
def sdelete_rows(
    self, dbname=None, tbname=None, start_ts=None, end_ts=None, ts_key=None
):
    """Delete rows from a table by timestamp.

    Timestamps are interpolated into the statement verbatim, so both string
    expressions and integer epoch values are accepted.

    Args:
        dbname (str, optional): When given, stored on ``self.dbname`` and used
            to qualify ``tbname``. Defaults to None.
        tbname (str, optional): Target table; None keeps ``self.tbname``.
        start_ts (timestamp, optional): Range start, or the exact timestamp to
            delete when ``end_ts`` is None. Defaults to None.
        end_ts (timestamp, optional): Range end (``between start and end``).
            Defaults to None.
        ts_key (str, optional): Timestamp column name; defaults to
            ``self.default_colts_name``.

    Raises:
        TypeError: If ``end_ts`` is given without ``start_ts``.

    Note:
        With neither ``start_ts`` nor ``end_ts`` the statement has no
        ``where`` clause.
    """
    if dbname is not None:
        self.dbname = dbname
        if tbname is not None:
            self.tbname = f"{self.dbname}.{tbname}"
    elif tbname is not None:
        self.tbname = tbname

    ts_col_name = self.default_colts_name if ts_key is None else ts_key
    base_del_sql = f"delete from {self.tbname} "
    if end_ts is not None:
        if start_ts is None:
            # Previously this surfaced as an opaque "argument of type
            # 'NoneType' is not iterable" TypeError.
            raise TypeError("start_ts is required when end_ts is given")
        base_del_sql += f"where {ts_col_name} between {start_ts} and {end_ts};"
    elif start_ts is not None:
        base_del_sql += f"where {ts_col_name} = {start_ts};"
    tdSql.execute(base_del_sql)
def check_stream_field_type(self, sql, input_function):
    """Check the column types reported for a stream's output table.

    Runs ``sql`` and compares the second field of result rows 1..N against
    the type names expected for ``input_function``.

    Args:
        sql (str): Statement whose result lists the fields to verify.
        input_function (str): Name of the scalar function used by the stream.
    """
    tdSql.query(sql)
    rows = tdSql.queryResult

    # (function names, expected type of rows 1, 2, ... in order)
    expectations = (
        (
            ["acos", "asin", "atan", "cos", "log", "pow", "sin", "sqrt", "tan"],
            ("DOUBLE", "DOUBLE"),
        ),
        (["lower", "ltrim", "rtrim", "upper"], ("VARCHAR", "VARCHAR", "NCHAR")),
        (["char_length", "length"], ("BIGINT", "BIGINT", "BIGINT")),
        (["concat", "concat_ws"], ("VARCHAR", "NCHAR", "NCHAR", "NCHAR")),
        (["substr"], ("VARCHAR", "VARCHAR", "VARCHAR", "NCHAR")),
    )

    # Any function not listed above falls back to INT / DOUBLE.
    expected_types = ("INT", "DOUBLE")
    for function_names, type_names in expectations:
        if input_function in function_names:
            expected_types = type_names
            break

    for row_index, type_name in enumerate(expected_types, start=1):
        tdSql.checkEqual(rows[row_index][1], type_name)
def round_handle(self, input_list):
    """Round every numeric cell of a result set to one decimal place.

    Cells whose exact type is ``datetime`` or ``str`` are kept untouched, and
    so are ``None`` cells (NULL values), which previously made ``round``
    raise ``TypeError``.

    Args:
        input_list (list): Rows (iterables of cells).

    Returns:
        list: A list of lists with the rounded cells.
    """
    tdLog.info("round rows ...")
    return [
        [
            cell
            if cell is None or type(cell) in (datetime, str)
            else round(cell, 1)
            for cell in row
        ]
        for row in input_list
    ]
def float_handle(self, input_list):
    """Convert digit-only cells in the first 13 columns of each row to float.

    A cell is converted when it is not a ``datetime``, is not None, its
    ``str()`` form consists only of digits, and its column index is <= 12.
    Every other cell is passed through unchanged.

    Args:
        input_list (list): Rows (iterables of cells).

    Returns:
        list: A list of tuples, one per input row.
    """
    tdLog.info("float rows ...")

    def to_float_if_digits(position, cell):
        convertible = (
            type(cell) != datetime
            and cell is not None
            and str(cell).isdigit()
            and position <= 12
        )
        return float(cell) if convertible else cell

    return [
        tuple(to_float_if_digits(position, cell) for position, cell in enumerate(row))
        for row in input_list
    ]
def str_ts_trans_bigint(self, str_ts):
    """Ask the server to cast a timestamp expression to bigint.

    Args:
        str_ts (str): Timestamp expression, inserted verbatim into the query.

    Returns:
        The first column of the first row returned by the cast query.
    """
    cast_sql = f"select cast({str_ts} as bigint)"
    tdSql.query(cast_sql)
    first_row = tdSql.queryResult[0]
    return first_row[0]
def cast_query_data(self, query_data):
    """Cast every non-NULL cell to the declared type of its column or tag.

    The declared types come from ``self.column_type_str`` followed by
    ``self.tag_type_str`` (comma-separated "name type" entries). Each cell is
    round-tripped through a ``select cast(...)`` query; cells declared as
    ``nchar(6)`` are cast to ``binary(6)`` instead.

    Args:
        query_data (list): Rows of a query result.

    Returns:
        list: A list of tuples holding the cast values.
    """
    tdLog.info("cast query data ...")
    declarations = self.column_type_str.split(",") + self.tag_type_str.split(",")

    cast_rows = []
    for row in query_data:
        cells = list(row)
        for position, cell in enumerate(cells):
            if cell is None:
                # NULL cells are left as they are.
                continue
            # Drop the leading name, keep the type part of "name type".
            type_name = " ".join(declarations[position].strip().split(" ")[1:])
            target_type = "binary(6)" if type_name == "nchar(6)" else type_name
            tdSql.query(f'select cast("{cell}" as {target_type})')
            cells[position] = tdSql.queryResult[0][0]
        cast_rows.append(tuple(cells))
    return cast_rows
def trans_time_to_s(self, runtime):
    """Convert a duration such as "1d", "2h", "1.5m" or "30s" to seconds.

    The unit is detected by looking for the letters d, h, m, s (in that
    order, case-insensitively) anywhere in the text; the first number found
    is the amount. Input without any of those letters yields 60.

    Args:
        runtime (str): Duration text.

    Returns:
        int: The duration in whole seconds (truncated).

    Raises:
        IndexError: If a unit letter is present but no number is.
    """
    text = str(runtime)
    lowered = text.lower()
    # Multipliers are applied one after another to keep the float
    # arithmetic identical to the previous ``x * 24 * 60 * 60`` form.
    unit_factors = (("d", (24, 60, 60)), ("h", (60, 60)), ("m", (60,)), ("s", ()))
    for unit, factors in unit_factors:
        if unit not in lowered:
            continue
        amount = re.findall(r"\d+\.?\d*", text.replace(" ", ""))[0]
        # Always go through float: the seconds branch used to hand the raw
        # string to int(), so fractional values like "1.5s" raised ValueError.
        seconds = float(amount)
        for factor in factors:
            seconds *= factor
        return int(seconds)
    return 60
def check_query_data(
    self,
    sql1,
    sql2,
    sorted=False,
    fill_value=None,
    tag_value_list=None,
    defined_tag_count=None,
    partition=True,
    use_exist_stb=False,
    subtable=None,
    reverse_check=False,
):
    """Poll two queries until their results agree (or, reversed, differ).

    Both queries are re-run every 0.2 s until the condition holds or
    ``self.stream_timeout`` is used up; at that point the final
    ``tdSql.checkEqual`` / ``tdSql.checkNotEqual`` decides the outcome.

    Args:
        sql1 (str): First query.
        sql2 (str): Second query; its rows are cast and reshaped when
            ``tag_value_list`` or ``use_exist_stb`` is set.
        sorted (bool, optional): Sort both results before comparing.
            Retries also sort when ``tag_value_list`` is given.
            Defaults to False.
        fill_value (str, optional): "LINEAR" rounds both results through
            ``round_handle`` before comparing. Defaults to None.
        tag_value_list (list, optional): Per-row tag values substituted into
            the rows of ``sql2``. Defaults to None.
        defined_tag_count (int, optional): Number of explicitly defined tags;
            required when ``tag_value_list`` is given. Defaults to None.
        partition (bool, optional): When True the tag values are appended,
            otherwise all tag columns become None. Defaults to True.
        use_exist_stb (bool, optional): Strip the last 13 columns of each
            ``sql2`` row and pad with None tags. Defaults to False.
        subtable (str, optional): Unused. Defaults to None.
        reverse_check (bool, optional): Wait for the results to differ
            instead of match. Defaults to False.

    Returns:
        False when the timeout branch is reached with zero elapsed latency
        (i.e. ``self.stream_timeout`` <= 0); otherwise None.
    """
    tdLog.info("checking query data ...")
    dvalue = None
    if tag_value_list:
        # Number of tag columns that were not explicitly defined.
        dvalue = len(self.tag_type_str.split(",")) - defined_tag_count

    def fetch(sort_results):
        # One polling round: run both queries and bring the results into a
        # comparable shape. This used to be copy-pasted three times.
        tdSql.query(sql1)
        res1 = tdSql.queryResult
        tdSql.query(sql2)
        res2 = (
            self.cast_query_data(tdSql.queryResult)
            if tag_value_list or use_exist_stb
            else tdSql.queryResult
        )
        tdSql.sql = sql1
        if tag_value_list:
            res1 = self.float_handle(res1)
            res2 = self.float_handle(res2)
            reshaped = []
            for i, v in enumerate(res2):
                if i < len(tag_value_list):
                    data_part = list(v)[: -(dvalue + defined_tag_count)]
                    if partition:
                        tag_part = list(tag_value_list[i]) + [None] * dvalue
                    else:
                        tag_part = [None] * len(self.tag_type_str.split(","))
                    reshaped.append(tuple(data_part + tag_part))
            res2 = reshaped
        elif use_exist_stb:
            res1 = self.float_handle(res1)
            res2 = self.float_handle(res2)
            res2 = [
                tuple(list(v)[:-(13)] + [None] * len(self.tag_type_str.split(",")))
                for v in res2
            ]
        if sort_results:
            res1.sort()
            res2.sort()
        if fill_value == "LINEAR":
            res1 = self.round_handle(res1)
            res2 = self.round_handle(res2)
        return res1, res2

    # The first round only sorts on request; retries additionally sort
    # whenever tag values are substituted (unchanged from the original).
    res1, res2 = fetch(sorted)
    latency = 0
    while (res1 == res2) if reverse_check else (res1 != res2):
        tdLog.info("query retrying ...")
        res1, res2 = fetch(sorted or tag_value_list)
        if latency < self.stream_timeout:
            latency += 0.2
            time.sleep(0.2)
        else:
            if latency == 0:
                return False
            if reverse_check:
                tdSql.checkNotEqual(res1, res2)
            else:
                tdSql.checkEqual(res1, res2)
def check_stream_res(self, sql, expected_res, max_delay):
    """Poll a query until it returns the expected number of rows.

    The query is re-run every 0.2 s while ``self.stream_timeout`` has not
    been used up; afterwards ``tdSql.checkEqual`` verifies the row count.

    Args:
        sql (str): Query to poll.
        expected_res (int): Expected value of ``tdSql.queryRows``.
        max_delay (int): When not None, a timeout with zero elapsed latency
            returns False instead of running the final check.

    Returns:
        False in the case described for ``max_delay``; otherwise None.
    """
    tdSql.query(sql)
    waited = 0
    while tdSql.queryRows != expected_res:
        tdSql.query(sql)
        if waited < self.stream_timeout:
            waited += 0.2
            time.sleep(0.2)
            continue
        if max_delay is not None and waited == 0:
            return False
        tdSql.checkEqual(tdSql.queryRows, expected_res)
def check_stream(self, sql1, sql2, expected_count, max_delay=None):
    """Verify a stream by row count first, then by content.

    Args:
        sql1 (str): Query whose row count is polled via ``check_stream_res``.
        sql2 (str): Query whose result is compared with that of ``sql1`` via
            ``check_query_data``.
        expected_count (int): Row count expected from ``sql1``.
        max_delay (int, optional): Passed through to ``check_stream_res``.
            Defaults to None.
    """
    # Step 1: wait for the expected number of rows.
    self.check_stream_res(sql1, expected_count, max_delay)
    # Step 2: compare the two result sets.
    self.check_query_data(sql1, sql2)
def cal_watermark_window_close_session_endts(
    self, start_ts, watermark=None, session=None
):
    """Compute the timestamp just past a watermark or session gap.

    Args:
        start_ts (int): Start timestamp.
        watermark (int, optional): Used as the gap when not None.
            Defaults to None.
        session (int, optional): Used as the gap when ``watermark`` is None.
            Defaults to None.

    Returns:
        int: ``start_ts + gap * self.offset + 1``.
    """
    gap = session if watermark is None else watermark
    return start_ts + gap * self.offset + 1
def cal_watermark_window_close_interval_endts(
    self, start_ts, interval, watermark=None
):
    """Compute the end of the interval window containing ``start_ts``.

    ``start_ts`` is truncated to whole ``self.offset`` units, advanced to the
    next multiple of ``interval`` (a full interval when already aligned) and,
    when given, shifted further by ``watermark``.

    Args:
        start_ts (int): Start timestamp.
        interval (int): Interval length, in ``self.offset`` units.
        watermark (int, optional): Extra delay, in ``self.offset`` units.
            Defaults to None.

    Returns:
        The computed end timestamp.
    """
    whole_units = int(start_ts / self.offset)
    window_end = (
        whole_units * self.offset
        + (interval - whole_units % interval) * self.offset
    )
    if watermark is None:
        return window_end
    return window_end + watermark * self.offset
def update_delete_history_data(self, delete):
    """Re-insert the recorded history row and optionally delete it again.

    A row at ``self.record_history_ts`` is inserted into the child table and
    then the normal table; when ``delete`` is truthy the same timestamp is
    afterwards deleted from both tables.

    Args:
        delete (bool): Whether to delete the rows after inserting them.
    """
    targets = (self.ctb_name, self.tb_name)
    for table in targets:
        self.sinsert_rows(tbname=table, ts_value=self.record_history_ts)
    if not delete:
        return
    for table in targets:
        self.sdelete_rows(
            tbname=table,
            start_ts=self.time_cast(self.record_history_ts, "-"),
        )
def get_timestamp_n_days_later(self, n=30):
    """Return the epoch timestamp of the moment ``n`` days from now.

    Args:
        n (int): Number of days added to the current local time.
            Defaults to 30.

    Returns:
        int: Timestamp in milliseconds.
    """
    target_moment = datetime.now() + timedelta(days=n)
    return int(target_moment.timestamp() * 1000)
def create_snode_if_not_exists(self, dnode_id=1):
    """Make sure at least one snode exists, creating one if necessary.

    Args:
        dnode_id (int, optional): Dnode on which to create the snode.
            Defaults to 1.

    Returns:
        bool: True if a snode already existed or is listed after creation;
        False if it is still missing or any exception occurred.
    """
    try:
        tdSql.query("show snodes")
        snode_count = tdSql.queryRows
        if snode_count > 0:
            tdLog.info(f"Snode already exists, found {snode_count} snode(s)")
            return True

        tdLog.info(f"No snode found, creating snode on dnode {dnode_id}")
        tdSql.execute(f"create snode on dnode {dnode_id}")

        # Re-list the snodes to confirm that the creation took effect.
        tdSql.query("show snodes")
        created = tdSql.queryRows > 0
        if created:
            tdLog.info(f"Snode created successfully on dnode {dnode_id}")
        else:
            tdLog.error(f"Failed to create snode on dnode {dnode_id}")
        return created
    except Exception as e:
        tdLog.error(f"Error creating snode: {e}")
        return False
def ensure_snode_ready(self, dnode_id=1, timeout=30):
    """Create a snode if needed and wait until it is listed.

    "Ready" currently means that ``show snodes`` returns at least one row; no
    status column is inspected.

    Args:
        dnode_id (int, optional): Dnode on which to create the snode.
            Defaults to 1.
        timeout (int, optional): Maximum number of one-second polls.
            Defaults to 30.

    Returns:
        bool: True once a snode is listed; False on creation failure, timeout
        or any exception.
    """
    try:
        if not self.create_snode_if_not_exists(dnode_id):
            return False

        wait_time = 0
        while wait_time < timeout:
            tdSql.query("show snodes")
            if tdSql.queryRows > 0:
                tdLog.info(f"Snode is ready after {wait_time} seconds")
                return True
            time.sleep(1)
            wait_time += 1

        tdLog.error(f"Snode not ready after {timeout} seconds")
        return False
    except Exception as e:
        tdLog.error(f"Error ensuring snode ready: {e}")
        return False
def drop_snode(self, snode_id=None):
    """Drop one snode, or every listed snode.

    Args:
        snode_id (int, optional): Snode to drop. When None, the first column
            of every row of ``show snodes`` is used. Defaults to None.

    Returns:
        bool: True when nothing had to be dropped or all drops were issued;
        False if any exception occurred.
    """
    try:
        tdSql.query("show snodes")
        if tdSql.queryRows == 0:
            tdLog.info("No snode exists to drop")
            return True

        if snode_id is not None:
            targets = [snode_id]
        else:
            targets = [row[0] for row in tdSql.queryResult]

        for target in targets:
            tdSql.execute(f"drop snode {target}")
            tdLog.info(f"Dropped snode {target}")
        return True
    except Exception as e:
        tdLog.error(f"Error dropping snode: {e}")
        return False
def prepare_data(
    self,
    interval=None,
    watermark=None,
    session=None,
    state_window=None,
    state_window_max=127,
    interation=3,
    range_count=None,
    precision="ms",
    fill_history_value=0,
    ext_stb=None,
    custom_col_index=None,
    col_value_type="random",
):
    """Reset the environment and create the database and tables for a case.

    Table names are derived from ``self.case_name``; destination tables add
    ``self.des_table_suffix``. The settings are also recorded in
    ``self.dataDict``.

    Args:
        interval (int, optional): Recorded in ``dataDict``. Defaults to None.
        watermark (int, optional): Recorded in ``dataDict``. Defaults to None.
        session (int, optional): Recorded in ``dataDict``. Defaults to None.
        state_window (str, optional): Recorded in ``dataDict``.
        state_window_max (int, optional): Recorded in ``dataDict``.
            Defaults to 127.
        interation (int, optional): Recorded under the key "iteration".
            Defaults to 3.
        range_count (int, optional): When not None, stored on
            ``self.range_count``. Defaults to None.
        precision (str, optional): When not None, stored on
            ``self.precision``. Defaults to "ms".
        fill_history_value (int, optional): 1 pre-inserts
            ``self.range_count`` history rows into the child and normal
            tables. Defaults to 0.
        ext_stb (bool, optional): Also create the external destination
            tables. Defaults to None.
        custom_col_index (optional): Forwarded to ``sinsert_rows``.
        col_value_type (str, optional): Forwarded to ``sinsert_rows``.
            Defaults to "random".
    """
    self.clean_env()

    case = self.case_name
    table_names = {
        "stb_name": f"{case}_stb",
        "ctb_name": f"{case}_ct1",
        "tb_name": f"{case}_tb1",
        "ext_stb_name": f"ext_{case}_stb",
        "ext_ctb_name": f"ext_{case}_ct1",
        "ext_tb_name": f"ext_{case}_tb1",
    }
    self.dataDict = {
        **table_names,
        "interval": interval,
        "watermark": watermark,
        "session": session,
        "state_window": state_window,
        "state_window_max": state_window_max,
        "iteration": interation,
        "range_count": range_count,
        "start_ts": 1655903478508,
    }

    if range_count is not None:
        self.range_count = range_count
    if precision is not None:
        self.precision = precision
    self.set_precision_offset(self.precision)

    # Source table names.
    self.stb_name = table_names["stb_name"]
    self.ctb_name = table_names["ctb_name"]
    self.tb_name = table_names["tb_name"]
    self.ext_stb_name = table_names["ext_stb_name"]
    self.ext_ctb_name = table_names["ext_ctb_name"]
    self.ext_tb_name = table_names["ext_tb_name"]

    # Destination table names.
    suffix = self.des_table_suffix
    self.stb_stream_des_table = f"{self.stb_name}{suffix}"
    self.ctb_stream_des_table = f"{self.ctb_name}{suffix}"
    self.tb_stream_des_table = f"{self.tb_name}{suffix}"
    self.ext_stb_stream_des_table = f"{self.ext_stb_name}{suffix}"
    self.ext_ctb_stream_des_table = f"{self.ext_ctb_name}{suffix}"
    self.ext_tb_stream_des_table = f"{self.ext_tb_name}{suffix}"

    self.date_time = self.genTs(precision=self.precision)[0]
    self.screateDb(dbname=self.dbname, precision=self.precision)

    if ext_stb:
        self.screate_stable(
            dbname=self.dbname, stbname=self.ext_stb_stream_des_table
        )
        self.screate_ctable(
            dbname=self.dbname,
            stbname=self.ext_stb_stream_des_table,
            ctbname=self.ext_ctb_stream_des_table,
        )
        self.screate_table(dbname=self.dbname, tbname=self.ext_tb_stream_des_table)

    self.screate_stable(dbname=self.dbname, stbname=self.stb_name)
    self.screate_ctable(
        dbname=self.dbname, stbname=self.stb_name, ctbname=self.ctb_name
    )
    self.screate_table(dbname=self.dbname, tbname=self.tb_name)

    if fill_history_value != 1:
        return
    for step in range(self.range_count):
        # History rows lie before date_time, one default interval apart.
        ts_value = str(self.date_time) + f"-{self.default_interval * (step + 1)}s"
        for table in (self.ctb_name, self.tb_name):
            self.sinsert_rows(
                tbname=table,
                ts_value=ts_value,
                custom_col_index=custom_col_index,
                col_value_type=col_value_type,
            )
        if step == 1:
            # The second history timestamp is kept for later update/delete.
            self.record_history_ts = ts_value
def generate_query_result_file(self, test_case, idx, sql):
    """Run one statement through the ``taos`` CLI and save its filtered output.

    The output is written to ``./<test_case>.<idx>.csv`` with CLI banner and
    "Query OK" lines removed. If the CLI cannot be started the error is
    logged and an empty file is written.

    Args:
        test_case (str): Test case name used in the file name.
        idx: Index used in the file name.
        sql (str): Statement passed to ``taos -s``.

    Returns:
        str: Path of the written file (also stored on
        ``self.query_result_file``).
    """
    self.query_result_file = f"./{test_case}.{idx}.csv"
    cfgPath = self.getClientCfgPath()
    # An argument list (no shell) avoids platform-specific quoting issues.
    cmd = ["taos", "-c", cfgPath, "-s", sql]
    try:
        result = subprocess.run(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            encoding="utf-8",
            errors="ignore",
            shell=False,
        )
        output = result.stdout.splitlines()
        if result.returncode != 0:
            # Previously the exit status and stderr were dropped silently.
            tdLog.debug(
                f"taos exited with code {result.returncode}: {result.stderr}"
            )
    except Exception as e:
        tdLog.error(f"Failed to run taos command: {e}")
        output = []

    # Lines containing any of these fragments are CLI noise, not results.
    ignore_patterns = [
        "Query OK",
        "Copyright",
        "Welcome to the TDengine TSDB Command",
        "Welcome to the TDengine Command Line Interface",
    ]
    filtered = [
        line for line in output if not any(pat in line for pat in ignore_patterns)
    ]

    with open(self.query_result_file, "w", encoding="utf-8") as fout:
        fout.writelines(line.rstrip("\r\n") + "\n" for line in filtered)
    return self.query_result_file
def run_sql(self, sql, db):
    """Execute one statement on a fresh connection, logging any failure.

    Args:
        sql (str): Statement to execute (errors are ignored by the executor).
        db (str): Database to ``USE`` first; skipped when falsy. A failing
            ``USE`` is logged and the statement is still executed.
    """
    session = self.newTdSql()
    if db:
        try:
            session.execute(f"USE {db};")
        except Exception as err:
            tdLog.error(f"USE数据库失败: {db}\n{err}")
    try:
        session.execute_ignore_error(sql)
    except Exception as err:
        tdLog.error(f"SQL执行失败: {sql}\n{err}")
def execute_query_file(self, inputfile, max_workers=8):
    """Execute the statements of a SQL file concurrently.

    The first non-empty line is assumed to be a ``use <db>`` statement; every
    following non-empty line is run through ``run_sql`` on a thread pool.

    Args:
        inputfile (str): Path of the SQL file.
        max_workers (int, optional): Thread pool size. Defaults to 8.
    """
    # Normalise the path so Windows-style separators work as well.
    inputfile = os.path.normpath(inputfile)
    if not os.path.exists(inputfile):
        tdLog.exit(f"Input file '{inputfile}' does not exist.")
        return
    tdLog.info(f"Executing query file: {inputfile}")

    def load_statements():
        # Return the non-empty, stripped lines using the first encoding that
        # decodes the whole file.
        for codec in ['utf-8', 'gbk', 'utf-8-sig', 'latin-1']:
            try:
                with open(inputfile, "r", encoding=codec, newline=None) as handle:
                    return [text for text in (raw.strip() for raw in handle) if text]
            except (UnicodeDecodeError, LookupError):
                continue
        return []

    lines = load_statements()
    if not lines:
        tdLog.exit(f"Failed to read file '{inputfile}' with supported encodings.")
        return

    # Second token of the leading "use <db>;" line.
    db = lines[0].split()[1].rstrip(";")
    # Strip the CLI "\G" marker and make each statement end in one ";".
    sql_lines = [text.replace("\\G", "").rstrip(";") + ";" for text in lines[1:]]
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        for statement in sql_lines:
            executor.submit(self.run_sql, statement, db)
def generate_query_result(self, inputfile, test_case):
    """Run a SQL file through the ``taos`` CLI and save the filtered output.

    The output is piped through grep/sed filters that drop banner lines and
    blank out run-dependent numbers (timings, costs, ...), then written to
    ``./temp_<test_case>.result``.

    Args:
        inputfile (str): Path of the SQL file passed to ``taos -f``.
        test_case (str): Test case name used in the result file name.

    Returns:
        str: Path of the result file, or None when ``inputfile`` is missing
        (after ``tdLog.exit`` has been called).
    """
    if not os.path.exists(inputfile):
        tdLog.exit(f"Input file '{inputfile}' does not exist.")
        return None

    self.query_result_file = f"./temp_{test_case}.result"
    cfgPath = self.getClientCfgPath()
    tdLog.info(
        f"Generating query result file: {self.query_result_file} using input file: {inputfile}"
    )

    # Filter stages shared by every platform. They used to be duplicated in
    # both branches; defining them once keeps the two pipelines in sync.
    shared_filters = (
        "| grep -v 'Query OK'|grep -v 'Copyright'| grep -Ev 'Welcome to the TDengine (TSDB )?Command' "
        "| grep -v 'Exec cost:' "
        "| sed -E 's/[[:space:]]*\\([0-9]+\\.[0-9]+s\\)/ /g' "
        # cost=0.000..1.111
        "| sed -E 's/cost=[0-9]+\\.[0-9]+\\.\\.[0-9]+\\.[0-9]+//g' "
        # cost=0.000(0.000)..1.111(1.111)
        "| sed -E 's/cost=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)\\.\\.[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/file_load_elapsed=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/file_load_elapsed=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/stt_load_elapsed=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/stt_load_elapsed=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/mem_load_elapsed=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/mem_load_elapsed=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/sma_load_elapsed=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/sma_load_elapsed=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/composed_elapsed=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/composed_elapsed=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/slowest_vgroup_id=[0-9]+//g' "
        "| sed -E 's/fetch_cost=[0-9]+\\.[0-9]+\\([0-9]+\\.[0-9]+\\)//g' "
        "| sed -E 's/fetch_cost=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/fetch_times=[0-9]+\\.[0-9]+\\([0-9]+\\)//g' "
        "| sed -E 's/fetch_times=[0-9]+//g' "
        "| sed -E 's/slow_deviation=[0-9]+\\.[0-9]+%//g' "
        "| sed -E 's/cost_ratio=[0-9]+\\.[0-9]+//g' "
        "| sed -E 's/data_deviation=-?[0-9]+\\.[0-9]+%//g' "
        "| sed -E 's/Planning Time: [0-9]+\\.[0-9]+ ms//g' "
        "| sed -E 's/Execution Time: [0-9]+\\.[0-9]+ ms//g' "
    )
    pipeline = f"taos -c {cfgPath} -f {inputfile} " + shared_filters

    if platform.system().lower() == "windows":
        # Windows additionally drops max_row_task and writes to a raw file
        # that is post-processed below.
        raw_file = f"{self.query_result_file}.raw"
        os.system(
            pipeline
            + "| sed -E 's/max_row_task=[0-9]+, //g' "
            + f"> {raw_file} "
        )
        time.sleep(1)
        with (
            open(raw_file, "r", encoding="utf-8") as fin,
            open(self.query_result_file, "w", encoding="utf-8") as fout,
        ):
            for line in fin:
                stripped = line.rstrip()
                # Skip lines consisting only of the "taos>" prompt.
                if re.match(r"^taos>\s*$", stripped):
                    continue
                # Prompt lines lose their trailing whitespace; all other
                # lines are copied unchanged.
                if stripped.startswith("taos>"):
                    fout.write(stripped + "\n")
                else:
                    fout.write(line)
        os.system(f"rm -f {raw_file}")
    else:
        # Elsewhere carriage returns are removed and the output is final.
        os.system(pipeline + "| tr -d '\\r' " + f"> {self.query_result_file}")
    return self.query_result_file
def _get_numeric_compare_tolerance(self, token1, token2, float_tolerance):
    """Return the allowed absolute difference between two numeric tokens.

    Args:
        token1 (str): First number as written in the file.
        token2 (str): Second number as written in the file.
        float_tolerance (float): Explicit tolerance; used as-is when > 0.

    Returns:
        Decimal: ``float_tolerance`` when positive; otherwise one unit in the
        last decimal place of the more precise token (``10 ** -digits``), or 0
        when neither token has fractional digits.
    """
    if float_tolerance > 0.0:
        return Decimal(str(float_tolerance))

    fraction_digits = 0
    for token in (token1, token2):
        # Only the mantissa counts; an exponent suffix is ignored.
        mantissa = token.lower().partition("e")[0]
        _, dot, fraction = mantissa.partition(".")
        if dot:
            fraction_digits = max(fraction_digits, len(fraction))

    if fraction_digits <= 0:
        return Decimal("0")
    return Decimal(1).scaleb(-fraction_digits)
def _normalize_result_line_for_compare(self, line):
    """Strip run-dependent CLI noise from one answer/result line.

    Args:
        line: A single line from an answer or result file.

    Returns:
        The line without trailing whitespace, a trailing "(N.NNs)" timing,
        "cost=a..b" ranges, planning/execution times and "max_row_task=N, ".
    """
    # Applied in order; every match is replaced by the empty string.
    noise_patterns = (
        r"\s*\([0-9]+\.[0-9]+s\)$",
        r"cost=[0-9]+\.[0-9]+\.\.[0-9]+\.[0-9]+",
        r"Planning Time: [0-9]+\.[0-9]+ ms",
        r"Execution Time: [0-9]+\.[0-9]+ ms",
        r"max_row_task=[0-9]+, ",
    )
    cleaned = line.rstrip()
    for pattern in noise_patterns:
        cleaned = re.sub(pattern, "", cleaned)
    return cleaned.rstrip()
def _compare_normalized_result_lines(self, file1, file2):
    """Compare two result files line by line after noise normalization.

    Args:
        file1: Expected result file path.
        file2: Actual result file path.

    Returns:
        True when both files have the same number of lines and every pair of
        normalized lines is identical; False otherwise.
    """

    def read_lines(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.read().splitlines()

    expected_lines = read_lines(file1)
    actual_lines = read_lines(file2)
    if len(expected_lines) != len(actual_lines):
        return False

    normalize = self._normalize_result_line_for_compare
    return all(
        normalize(expected) == normalize(actual)
        for expected, actual in zip(expected_lines, actual_lines)
    )
def _compare_file_lines_with_float_tolerance(self, file1, file2, float_tolerance):
    """Compare two result files, allowing small numeric differences.

    Lines are normalized first. Lines that still differ are accepted when
    all text between the numbers is identical and every pair of numbers
    differs by no more than the tolerance returned by
    ``_get_numeric_compare_tolerance``.

    Args:
        file1: Expected result file path.
        file2: Actual result file path.
        float_tolerance (float): Passed to ``_get_numeric_compare_tolerance``.

    Returns:
        bool: True when the files are equivalent, False otherwise.
    """
    number_pattern = re.compile(r"[-+]?(?:\d+\.\d+|\d+|\.\d+)(?:[eE][-+]?\d+)?")

    def read_lines(path):
        with open(path, "r", encoding="utf-8", errors="ignore") as handle:
            return handle.read().splitlines()

    def lines_equivalent(left, right):
        left = self._normalize_result_line_for_compare(left)
        right = self._normalize_result_line_for_compare(right)
        if left == right:
            return True

        left_numbers = list(number_pattern.finditer(left))
        right_numbers = list(number_pattern.finditer(right))
        if len(left_numbers) != len(right_numbers):
            return False

        left_pos = 0
        right_pos = 0
        for left_match, right_match in zip(left_numbers, right_numbers):
            # The text preceding each number has to match exactly.
            if left[left_pos:left_match.start()] != right[right_pos:right_match.start()]:
                return False
            left_token = left_match.group(0)
            right_token = right_match.group(0)
            try:
                left_value = Decimal(left_token)
                right_value = Decimal(right_token)
            except InvalidOperation:
                # Not parseable as a number: require identical text.
                if left_token != right_token:
                    return False
            else:
                allowed = self._get_numeric_compare_tolerance(
                    left_token, right_token, float_tolerance
                )
                if abs(left_value - right_value) > allowed:
                    return False
            left_pos = left_match.end()
            right_pos = right_match.end()
        # Whatever follows the last number has to match as well.
        return left[left_pos:] == right[right_pos:]

    expected_lines = read_lines(file1)
    actual_lines = read_lines(file2)
    if len(expected_lines) != len(actual_lines):
        return False
    return all(
        lines_equivalent(expected, actual)
        for expected, actual in zip(expected_lines, actual_lines)
    )
def compare_result_files(self, file1, file2, float_tolerance=0.0):
    """Compare an expected and an actual result file.

    The files are compared with ``diff -u`` (on normalized copies) or, on
    Windows, with ``fc /W`` on copies without blank lines. If the tool
    reports a difference, the files are still accepted when they match after
    line normalization or, on Windows only, within the float tolerance.

    Args:
        file1: Expected result file path.
        file2: Actual result file path.
        float_tolerance (float, optional): Tolerance for the Windows numeric
            fallback comparison. Defaults to 0.0.

    Returns:
        bool: True when the files are considered identical; False when they
        differ or any error occurred (errors are logged at debug level).
    """
    is_windows = platform.system().lower() == "windows"
    cmd = "fc" if is_windows else "diff"
    normalized_file1 = None
    normalized_file2 = None
    # Temporary copies created for fc; removed in the finally block so they
    # do not leak when reading an input file fails.
    temp_paths = []
    try:
        if not is_windows:
            normalized_file1 = self._normalize_diff_file(file1)
            normalized_file2 = self._normalize_diff_file(file2)
            tdLog.info(f"cmd: {cmd} -u --color {file1} {file2}")
            result = subprocess.run(
                [cmd, "-u", "--color", normalized_file1, normalized_file2],
                text=True,
                capture_output=True,
            )
            tdLog.info(f"result: {result}")
        else:
            file1 = os.path.abspath(os.path.normpath(file1))
            file2 = os.path.abspath(os.path.normpath(file2))
            # fc compares copies that contain only the non-empty lines.
            for source in (file1, file2):
                with tempfile.NamedTemporaryFile(
                    mode='w', delete=False, encoding='utf-8'
                ) as tmp:
                    temp_paths.append(tmp.name)
                    with open(source, 'r', encoding='utf-8', errors='ignore') as f:
                        tmp.writelines(line for line in f if line.strip())
            result = subprocess.run(
                [cmd, "/W", *temp_paths],
                text=True,
                capture_output=True,
                encoding="utf-8",
                errors="replace",
            )
            # Windows fc: returncode 0 means files are identical.
            if result.returncode == 0:
                return True

        if result.returncode != 0:
            if self._compare_normalized_result_lines(file1, file2):
                tdLog.info("Result files matched after output normalization.")
                return True
            if is_windows and self._compare_file_lines_with_float_tolerance(
                file1, file2, float_tolerance
            ):
                tdLog.info(
                    "Result files matched after Windows output normalization."
                    if float_tolerance <= 0.0
                    else "Result files matched after Windows output normalization "
                    f"with float tolerance {float_tolerance}."
                )
                return True
            tdLog.info(f"{cmd} result.returncode: {result.returncode}")
            tdLog.info(f"{cmd} result.stdout: {result.stdout}")
            tdLog.info(f"{cmd} result.stderr: {result.stderr}")
            return False

        if result.stdout:
            tdLog.debug(f"Differences between {file1} and {file2}")
            tdLog.notice(f"\r\n{result.stdout}")
            return False
        elif result.stderr:
            tdLog.info(f"{cmd} result.stderr: {result.stderr}")
            return False
        else:
            return True
    except FileNotFoundError as e:
        # Raised both for a missing comparison tool and for a missing input
        # file, so report what was actually not found.
        tdLog.debug(
            f"File or command not found while comparing {file1} and {file2} "
            f"with '{cmd}': {e}"
        )
        return False
    except Exception as e:
        tdLog.debug(f"An error occurred: {e}")
        return False
    finally:
        for temp_file in (normalized_file1, normalized_file2, *temp_paths):
            if temp_file and os.path.exists(temp_file):
                os.remove(temp_file)
def _normalize_diff_file(self, input_file):
    """Copy a file to a temp file, collapsing whitespace-only "|" lines.

    A line made of spaces/tabs followed by "|" is rewritten as a bare "|";
    its original line ending is kept. All other lines are copied verbatim.

    Args:
        input_file: Path of the UTF-8 file to normalize.

    Returns:
        str: Path of the temporary file; the caller has to delete it.
    """
    with open(input_file, "r", encoding="utf-8", newline="") as source, \
            tempfile.NamedTemporaryFile(
                mode="w", delete=False, encoding="utf-8", newline=""
            ) as target:
        for text in source:
            if not re.fullmatch(r"[ \t]+\|\r?\n", text):
                target.write(text)
                continue
            target.write("|" + ("\r\n" if text.endswith("\r\n") else "\n"))
        return target.name
def compare_query_with_result_file(
    self, idx, sql, resultFile, test_case, float_tolerance=0.0
):
    """Run a statement via the CLI and compare its output with an answer file.

    The generated result file is kept after the comparison.

    Args:
        idx: Index used in the generated result file name.
        sql (str): Statement to execute.
        resultFile (str): Path of the expected result file.
        test_case (str): Test case name used in the generated file name.
        float_tolerance (float, optional): Forwarded to
            ``compare_result_files``. Defaults to 0.0.
    """
    self.generate_query_result_file(test_case, idx, sql)
    if self.compare_result_files(
        resultFile, self.query_result_file, float_tolerance=float_tolerance
    ):
        tdLog.info("Test passed: Result files are identical.")
        return

    # Report the caller's location; the message used to print the line
    # number twice instead of the file name.
    caller = inspect.getframeinfo(inspect.stack()[1][0])
    tdLog.exit(
        f"{caller.filename}(line:{caller.lineno}) failed: expect_file:{resultFile} != result_file:{self.query_result_file} "
    )
def compare_testcase_result(
    self, inputfile, expected_file, test_case, float_tolerance=0.0
):
    """Run a SQL file via the CLI and compare its output with an answer file.

    On success the generated result file is deleted; on failure it is kept
    and ``tdLog.exit`` is called with the caller's location.

    Args:
        inputfile (str): Path of the SQL file to execute.
        expected_file (str): Path of the expected result file.
        test_case (str): Test case name used in the generated file name.
        float_tolerance (float, optional): Forwarded to
            ``compare_result_files``. Defaults to 0.0.
    """
    test_result_file = self.generate_query_result(inputfile, test_case)
    if self.compare_result_files(
        expected_file, test_result_file, float_tolerance=float_tolerance
    ):
        tdLog.info("Test passed: Result files are identical.")
        # Portable replacement for "rm -f": a missing file is not an error.
        try:
            os.remove(test_result_file)
        except FileNotFoundError:
            pass
        return

    # The message used to print the line number twice instead of the file
    # name, and a second, redundant tdLog.exit followed it.
    caller = inspect.getframeinfo(inspect.stack()[1][0])
    tdLog.exit(
        f"{caller.filename}(line:{caller.lineno}) failed: sqlfile is {inputfile}, expect_file:{expected_file} != result_file:{test_result_file} "
    )
def get_subtable(self, tbname_pre):
    """Return the first table whose name contains ``tbname_pre``.

    Args:
        tbname_pre (str): Fragment to look for in the names listed by
            ``show tables``.

    Returns:
        The first matching table name, or None when nothing matches.
    """
    tdSql.query("show tables")
    table_names = [row[0] for row in tdSql.queryResult]
    return next((name for name in table_names if tbname_pre in name), None)
def get_subtable_wait(self, tbname_pre):
    """Wait until a table whose name contains ``tbname_pre`` exists.

    ``get_subtable`` is polled once per second for up to
    ``self.stream_timeout`` seconds.

    Args:
        tbname_pre (str): Fragment to look for in the table names.

    Returns:
        The matching table name, or False when none appeared in time.
    """
    tbname = self.get_subtable(tbname_pre)
    latency = 0
    while tbname is None:
        tbname = self.get_subtable(tbname_pre)
        if tbname is not None:
            # Found: return right away. Previously the loop still slept, or
            # returned False at the deadline although the table was found.
            break
        if latency >= self.stream_timeout:
            return False
        latency += 1
        time.sleep(1)
    return tbname
def get_group_id_from_stb(self, stbname):
    """Wait for a super table to contain data and return its first group id.

    The ``select distinct group_id`` query is polled once per second for up
    to ``self.default_interval`` attempts.

    Args:
        stbname (str): Super table name.

    Returns:
        The first ``group_id`` value, or False when no row appeared in time.
    """
    group_sql = f"select distinct group_id from {stbname}"
    tdSql.query(group_sql)
    cnt = 0
    while len(tdSql.queryResult) == 0:
        tdSql.query(group_sql)
        if len(tdSql.queryResult) != 0:
            # Rows arrived: stop right away. Previously the loop still
            # slept, or returned False at the deadline despite having rows.
            break
        if cnt >= self.default_interval:
            return False
        cnt += 1
        time.sleep(1)
    return tdSql.queryResult[0][0]
def update_json_file_replica(
    self, json_file_path, new_replica_value, output_file_path=None
):
    """
    Set the 'replica' value of every 'dbinfo' entry in a JSON file.

    Parameters:
        json_file_path (str): Path of the JSON file to read.
        new_replica_value (int): Value stored under each 'dbinfo' -> 'replica'.
        output_file_path (str, optional): Where the updated JSON is written.
                                          When omitted (or empty), the input
                                          file is overwritten.
    Returns:
        None. Errors are printed instead of being raised.
    """
    try:
        with open(json_file_path, "r", encoding="utf-8") as source:
            config = json.load(source)
        # Entries of 'databases' without a 'dbinfo' key are left untouched.
        for entry in config["databases"]:
            if "dbinfo" not in entry:
                continue
            entry["dbinfo"]["replica"] = new_replica_value
        # Serialize before opening the target so a serialization failure
        # cannot leave a truncated file behind.
        serialized = json.dumps(config, indent=4, ensure_ascii=False)
        destination = output_file_path if output_file_path else json_file_path
        with open(destination, "w", encoding="utf-8") as target:
            target.write(serialized)
    except FileNotFoundError:
        # The message always names the input path.
        print(f"File not found: {json_file_path}")
    except KeyError as e:
        # For example 'databases' is missing from the document.
        print(f"Key error: {e}")
    except json.JSONDecodeError as e:
        # The input file is not valid JSON.
        print(f"JSON decode error: {e}")
    except Exception as e:
        # Anything else is reported the same best-effort way.
        print(f"An error occurred: {e}")
def waitTransactionZeroWithTdsql(self, td_sql, timeout=300):
    """Poll "show transactions;" once per second until the query returns 0.

    Parameters:
        td_sql: Object whose query() method runs the statement.
            NOTE(review): the return value is presumably the row count - confirm.
        timeout: Maximum number of polls (one second apart).

    Returns:
        True as soon as a poll returns 0. If every poll returns something
        else, the timeout is reported through tdLog.exit.
    """
    polls = 0
    while polls < timeout:
        if td_sql.query("show transactions;") == 0:
            tdLog.info("transaction count became zero.")
            return True
        polls += 1
        time.sleep(1)
    tdLog.exit(
        f"Timeout after {timeout} seconds waiting for transaction count to become zero."
    )
def is_json(msg):
    """Return True when msg is a str that parses as JSON, otherwise False.

    Non-str values (including bytes) are never treated as JSON.
    """
    # NOTE(review): an identical is_json is defined again further down in
    # this file; the later definition is the one that stays bound.
    if not isinstance(msg, str):
        return False
    try:
        json.loads(msg)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # extremely deep nesting. KeyboardInterrupt and SystemExit are no
        # longer swallowed, unlike with a bare except.
        return False
    return True
def get_path(tool="taosd"):
    """Return the location of the given tool as resolved by binFile.

    Parameters:
        tool: Name of the executable to look up; defaults to "taosd".

    The lookup is delegated entirely to binFile. An older implementation
    that walked the project tree with os.walk was dead commented-out code
    and has been removed.
    """
    return binFile(tool)
def dict2toml(in_dict: dict, file: str):
    """Write in_dict to the path `file` in TOML format.

    Returns:
        "" (without touching the file) when in_dict is not a dict,
        otherwise None after the file has been written.
    """
    if isinstance(in_dict, dict):
        with open(file, "w") as toml_out:
            toml.dump(in_dict, toml_out)
        return None
    return ""
def is_json(msg):
    """Return True when msg is a str that parses as JSON, otherwise False.

    Non-str values (including bytes) are never treated as JSON.
    """
    # NOTE(review): this re-defines the is_json that appears earlier in this
    # file with the same logic; this later definition is the one that stays
    # bound.
    if not isinstance(msg, str):
        return False
    try:
        json.loads(msg)
    except (ValueError, RecursionError):
        # json.JSONDecodeError is a ValueError; RecursionError covers
        # extremely deep nesting. KeyboardInterrupt and SystemExit are no
        # longer swallowed, unlike with a bare except.
        return False
    return True
# TDCom instance bound to the name tdCom.
tdCom = TDCom()