diff --git a/.vscode/launch.json b/.vscode/launch.json index 9cf273a..faf1ffb 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -9,6 +9,6 @@ "program": ".", "type": "go", "cwd": "${cwd}", - "args": ["generate", "--module", "github.com/vloldik/dbml-gen", "-i", "./test/test.dbml", "-o", "./test/output", "--backend", "gorm"] + "args": ["generate", "--module", "github.com/vloldik/dbml-gen", "-i", "./example/test.dbml", "-o", "./example/output", "--backend", "gorm"] }] } \ No newline at end of file diff --git a/example/output/ecommerce/merchant.go b/example/output/ecommerce/merchant.go index dd137f6..ad206fc 100644 --- a/example/output/ecommerce/merchant.go +++ b/example/output/ecommerce/merchant.go @@ -12,7 +12,7 @@ type Merchant struct { CreatedAt *string `gorm:"column:created_at"` AdminID *int `gorm:"column:admin_id"` Country *public.Country `gorm:"foreignKey:Code;References:CountryCode"` - User *public.User `gorm:"foreignKey:ID;References:AdminID"` + Admin *public.User `gorm:"foreignKey:ID;References:AdminID"` } func (Merchant) TableName() string { diff --git a/example/output/migrates/migrate.go b/example/output/migrates/migrate.go index 958447f..b71dfa2 100644 --- a/example/output/migrates/migrate.go +++ b/example/output/migrates/migrate.go @@ -7,5 +7,5 @@ import ( ) func MigrateAll(db *gorm.DB) error { - return db.AutoMigrate(&ecommerce.Order{}, &ecommerce.Product{}, &ecommerce.MerchantPeriod{}, &public.User{}, &public.Country{}, &ecommerce.ProductTag{}, &ecommerce.Merchant{}, &ecommerce.OrderItem{}) + return db.AutoMigrate(&ecommerce.ProductTag{}, &ecommerce.MerchantPeriod{}, &ecommerce.Merchant{}, &public.Country{}, &ecommerce.Order{}, &public.User{}, &ecommerce.OrderItem{}, &ecommerce.Product{}) } diff --git a/internal/generator/gormgen/struct_generator.go b/internal/generator/gormgen/struct_generator.go index 33c2c95..348d161 100644 --- a/internal/generator/gormgen/struct_generator.go +++ b/internal/generator/gormgen/struct_generator.go @@ -105,7 +105,7 @@ func (sg *GORMStructGenerator) createFieldRelation(relation *models.Relationship // True if we want to use []list isX_ToMany := relation.RelationType == models.ManyToMany || relation.RelationType == models.OneToMany qual := sg.getStructQualifier(relation.ToTable) - createdFieldName := sg.createRelatedFieldName(relation.ToField, relation.ToTable, isX_ToMany) + createdFieldName := sg.createRelatedFieldName(relation.FromField, relation.ToTable, isX_ToMany) createdField := jen.Id(createdFieldName) if isX_ToMany { @@ -186,11 +186,14 @@ func (sg *GORMStructGenerator) getStructQualifier(table *models.Table) string { return qual } -func (sg *GORMStructGenerator) createRelatedFieldName(field *models.Field, table *models.Table, isX_toMany bool) string { - relatedFieldName := field.DisplayName() +func (sg *GORMStructGenerator) createRelatedFieldName(fromField *models.Field, toTable *models.Table, isX_toMany bool) string { + relatedFieldName := fromField.DisplayName() relatedFieldName, found := strings.CutSuffix(relatedFieldName, "Id") + if !found { + relatedFieldName, found = strings.CutSuffix(relatedFieldName, "ID") + } if len(relatedFieldName) < 2 || !found { - relatedFieldName = table.DisplayName() + relatedFieldName = toTable.DisplayName() } if sg.structFields.hasName(relatedFieldName) { relatedFieldName = "Related" + relatedFieldName