fix(cli): dynamically load pg module in "db pull" (#2421)

This commit is contained in:
Yiming Cao 2026-02-27 22:48:52 -05:00 committed by GitHub
parent 3336505ed6
commit 7de2af2be2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -1,10 +1,9 @@
import type { ZModelServices } from '@zenstackhq/language';
import type { Attribute, BuiltinType, Enum, Expression } from '@zenstackhq/language/ast';
import { AstFactory, DataFieldAttributeFactory, ExpressionBuilder } from '@zenstackhq/language/factory';
import { Client } from 'pg';
import { CliError } from '../../../cli-error';
import { getAttributeRef, getDbName, getFunctionRef, normalizeDecimalDefault, normalizeFloatDefault } from '../utils';
import type { IntrospectedEnum, IntrospectedSchema, IntrospectedTable, IntrospectionProvider } from './provider';
import type { ZModelServices } from '@zenstackhq/language';
import { CliError } from '../../../cli-error';
/**
* Maps PostgreSQL internal type names to their standard SQL names for comparison.
@ -110,8 +109,8 @@ const pgTypnameToZenStackNativeType: Record<string, string> = {
export const postgresql: IntrospectionProvider = {
isSupportedFeature(feature) {
const supportedFeatures = ['Schema', 'NativeEnum'];
return supportedFeatures.includes(feature);
const supportedFeatures = ['Schema', 'NativeEnum'];
return supportedFeatures.includes(feature);
},
getBuiltinType(type) {
const t = (type || '').toLowerCase();
@ -176,7 +175,11 @@ export const postgresql: IntrospectionProvider = {
return { type: 'Unsupported' as const, isArray };
}
},
async introspect(connectionString: string, options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' }): Promise<IntrospectedSchema> {
async introspect(
connectionString: string,
options: { schemas: string[]; modelCasing: 'pascal' | 'camel' | 'snake' | 'none' },
): Promise<IntrospectedSchema> {
const { Client } = await import('pg');
const client = new Client({ connectionString });
await client.connect();
@ -233,7 +236,7 @@ export const postgresql: IntrospectionProvider = {
}
}
// Fall through to typeCastingConvert if datatype_name lookup fails
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
switch (fieldType) {
@ -243,7 +246,7 @@ export const postgresql: IntrospectionProvider = {
}
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
// Fallback to string literal for other DateTime defaults
@ -256,19 +259,19 @@ export const postgresql: IntrospectionProvider = {
}
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
return (ab) => ab.NumberLiteral.setValue(val);
case 'Float':
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
return normalizeFloatDefault(val);
case 'Decimal':
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
return normalizeDecimalDefault(val);
@ -277,7 +280,7 @@ export const postgresql: IntrospectionProvider = {
case 'String':
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
if (val.startsWith("'") && val.endsWith("'")) {
@ -286,12 +289,12 @@ export const postgresql: IntrospectionProvider = {
return (ab) => ab.StringLiteral.setValue(val);
case 'Json':
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
return (ab) => ab.StringLiteral.setValue(val);
case 'Bytes':
if (val.includes('::')) {
return typeCastingConvert({defaultValue,enums,val,services});
return typeCastingConvert({ defaultValue, enums, val, services });
}
return (ab) => ab.StringLiteral.setValue(val);
}
@ -303,7 +306,9 @@ export const postgresql: IntrospectionProvider = {
);
}
console.warn(`Unsupported default value type: "${defaultValue}" for field type "${fieldType}". Skipping default value.`);
console.warn(
`Unsupported default value type: "${defaultValue}" for field type "${fieldType}". Skipping default value.`,
);
return null;
},
@ -311,7 +316,10 @@ export const postgresql: IntrospectionProvider = {
const factories: DataFieldAttributeFactory[] = [];
// Add @updatedAt for DateTime fields named updatedAt or updated_at
if (fieldType === 'DateTime' && (fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')) {
if (
fieldType === 'DateTime' &&
(fieldName.toLowerCase() === 'updatedat' || fieldName.toLowerCase() === 'updated_at')
) {
factories.push(new DataFieldAttributeFactory().setDecl(getAttributeRef('@updatedAt', services)));
}
@ -338,8 +346,7 @@ export const postgresql: IntrospectionProvider = {
dbAttr &&
defaultDatabaseType &&
(defaultDatabaseType.type !== normalizedDatatype ||
(defaultDatabaseType.precision &&
defaultDatabaseType.precision !== (length ?? precision)))
(defaultDatabaseType.precision && defaultDatabaseType.precision !== (length ?? precision)))
) {
const dbAttrFactory = new DataFieldAttributeFactory().setDecl(dbAttr);
// Only add length/precision if it's meaningful (not the standard bit width for the type)
@ -628,7 +635,17 @@ WHERE
ORDER BY "ns"."nspname", "cls"."relname" ASC;
`;
function typeCastingConvert({defaultValue, enums, val, services}:{val: string, enums: Enum[], defaultValue:string, services:ZModelServices}): ((builder: ExpressionBuilder) => AstFactory<Expression>) | null {
function typeCastingConvert({
defaultValue,
enums,
val,
services,
}: {
val: string;
enums: Enum[];
defaultValue: string;
services: ZModelServices;
}): ((builder: ExpressionBuilder) => AstFactory<Expression>) | null {
const [value, type] = val
.replace(/'/g, '')
.split('::')
@ -653,9 +670,7 @@ function typeCastingConvert({defaultValue, enums, val, services}:{val: string, e
}
const enumField = enumDef.fields.find((v) => getDbName(v) === value);
if (!enumField) {
throw new CliError(
`Enum value ${value} not found in enum ${type} for default value ${defaultValue}`,
);
throw new CliError(`Enum value ${value} not found in enum ${type} for default value ${defaultValue}`);
}
return (ab) => ab.ReferenceExpr.setTarget(enumField);
}