Skip to content

Commit

Permalink
Extract model filter to the server connection
Browse files Browse the repository at this point in the history
- this allows us to use the same filtering code in multiple places
  • Loading branch information
mrdjohnson committed Jun 27, 2024
1 parent 1ece694 commit 128124b
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 18 deletions.
7 changes: 7 additions & 0 deletions src/features/connections/servers/A1111ServerConnection.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ class A1111ServerConnection extends ServerConnection<IA1111Model> {

return camelcaseKeys<IA1111Model[]>(response.data).map(LanguageModel.fromIA1111Model)
}

override modelFilter(model: A1111LanguageModel, filterText: string) {
return (
model.modelName.toLowerCase().includes(filterText.toLowerCase()) ||
model.label.toLowerCase().includes(filterText.toLowerCase())
)
}
}

export default A1111ServerConnection
7 changes: 7 additions & 0 deletions src/features/connections/servers/OpenAiServerConnection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ class OpenAiServerConnection extends ServerConnection<IOpenAiModel> {

return trueResponse.map(model => LanguageModel.fromIOpenAiModel(model))
}

override modelFilter(model: OpenAiLanguageModel, filterText: string) {
return (
model.modelName.toLowerCase().includes(filterText.toLowerCase()) ||
model.ownedBy.toLowerCase().includes(filterText.toLowerCase())
)
}
}

export default OpenAiServerConnection
4 changes: 4 additions & 0 deletions src/features/connections/servers/ServerConnection.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ abstract class ServerConnection<
) => ReactNode

abstract api: BaseApi

modelFilter(model: LanguageModelType<BaseModelType>, filterText: string) {
return model.modelName.toLowerCase().includes(filterText.toLowerCase())
}
}

export default ServerConnection
7 changes: 1 addition & 6 deletions src/features/settings/panels/model/A1111ModelPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,7 @@ const A1111ModelPanel = observer(({ connection }: { connection: A1111ServerConne
onItemSelected={model =>
connectionModelStore.dataStore.setSelectedModel(model, connection.id)
}
itemFilter={(model, filterText) => {
return (
model.modelName.toLowerCase().includes(filterText.toLowerCase()) ||
model.label.toLowerCase().includes(filterText.toLowerCase())
)
}}
itemFilter={connection.modelFilter}
primarySortTypeLabel={connection.primaryHeader}
renderRow={renderRow}
getIsItemSelected={model =>
Expand Down
4 changes: 1 addition & 3 deletions src/features/settings/panels/model/LmsModelPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ const LmsModelPanel = observer(({ connection }: { connection: LmsServerConnectio
onItemSelected={model =>
connectionModelStore.dataStore.setSelectedModel(model, connection.id)
}
itemFilter={(model, filterText) => {
return model.modelName.toLowerCase().includes(filterText.toLowerCase())
}}
itemFilter={connection.modelFilter}
primarySortTypeLabel={connection.primaryHeader}
renderRow={renderRow}
getIsItemSelected={model =>
Expand Down
4 changes: 1 addition & 3 deletions src/features/settings/panels/model/OllamaModelPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,7 @@ const OllamaModelPanelTable = observer(
items={connection.models}
sortTypes={connection.modelTableHeaders}
primarySortTypeLabel="name"
itemFilter={(model: IOllamaModel, filterText: string) =>
model.name.toLowerCase().includes(filterText.toLowerCase())
}
itemFilter={connection.modelFilter}
renderRow={renderRow}
getItemKey={model => model.name}
onItemSelected={model =>
Expand Down
7 changes: 1 addition & 6 deletions src/features/settings/panels/model/OpenAiModelPanel.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,7 @@ const OpenAiModelPanel = observer(({ connection }: { connection: OpenAiServerCon
onItemSelected={model =>
connectionModelStore.dataStore.setSelectedModel(model, connection.id)
}
itemFilter={(model, filterText) => {
return (
model.modelName.toLowerCase().includes(filterText.toLowerCase()) ||
model.ownedBy.toLowerCase().includes(filterText.toLowerCase())
)
}}
itemFilter={connection.modelFilter}
primarySortTypeLabel={connection.primaryHeader}
renderRow={renderRow}
getIsItemSelected={model =>
Expand Down

0 comments on commit 128124b

Please sign in to comment.