mirror of
https://github.com/immich-app/immich
synced 2025-11-14 17:36:12 +00:00
feat(ml): better multilingual search with nllb models (#13567)
This commit is contained in:
parent
838a8dd9a6
commit
6789c2ac19
16 changed files with 301 additions and 18 deletions
|
|
@ -191,6 +191,11 @@ export class SmartSearchDto extends BaseSearchDto {
|
|||
@IsNotEmpty()
|
||||
query!: string;
|
||||
|
||||
@IsString()
|
||||
@IsNotEmpty()
|
||||
@Optional()
|
||||
language?: string;
|
||||
|
||||
@IsInt()
|
||||
@Min(1)
|
||||
@Type(() => Number)
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ export interface Face {
|
|||
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
|
||||
export type DetectedFaces = { faces: Face[] } & VisualResponse;
|
||||
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
|
||||
export type TextEncodingOptions = ModelOptions & { language?: string };
|
||||
|
||||
@Injectable()
|
||||
export class MachineLearningRepository {
|
||||
|
|
@ -170,8 +171,8 @@ export class MachineLearningRepository {
|
|||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
||||
async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
|
||||
async encodeText(urls: string[], text: string, { language, modelName }: TextEncodingOptions) {
|
||||
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName, options: { language } } } };
|
||||
const response = await this.predict<ClipTextualResponse>(urls, { text }, request);
|
||||
return response[ModelTask.SEARCH];
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import { BadRequestException } from '@nestjs/common';
|
||||
import { mapAsset } from 'src/dtos/asset-response.dto';
|
||||
import { SearchSuggestionType } from 'src/dtos/search.dto';
|
||||
import { SearchService } from 'src/services/search.service';
|
||||
|
|
@ -15,6 +16,7 @@ describe(SearchService.name, () => {
|
|||
|
||||
beforeEach(() => {
|
||||
({ sut, mocks } = newTestService(SearchService));
|
||||
mocks.partner.getAll.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
it('should work', () => {
|
||||
|
|
@ -155,4 +157,83 @@ describe(SearchService.name, () => {
|
|||
expect(mocks.search.getCameraModels).toHaveBeenCalledWith([authStub.user1.user.id], expect.anything());
|
||||
});
|
||||
});
|
||||
|
||||
describe('searchSmart', () => {
|
||||
beforeEach(() => {
|
||||
mocks.search.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] });
|
||||
mocks.machineLearning.encodeText.mockResolvedValue('[1, 2, 3]');
|
||||
});
|
||||
|
||||
it('should raise a BadRequestException if machine learning is disabled', async () => {
|
||||
mocks.systemMetadata.get.mockResolvedValue({
|
||||
machineLearning: { enabled: false },
|
||||
});
|
||||
|
||||
await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError(
|
||||
new BadRequestException('Smart search is not enabled'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should raise a BadRequestException if smart search is disabled', async () => {
|
||||
mocks.systemMetadata.get.mockResolvedValue({
|
||||
machineLearning: { clip: { enabled: false } },
|
||||
});
|
||||
|
||||
await expect(sut.searchSmart(authStub.user1, { query: 'test' })).rejects.toThrowError(
|
||||
new BadRequestException('Smart search is not enabled'),
|
||||
);
|
||||
});
|
||||
|
||||
it('should work', async () => {
|
||||
await sut.searchSmart(authStub.user1, { query: 'test' });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ modelName: expect.any(String) }),
|
||||
);
|
||||
expect(mocks.search.searchSmart).toHaveBeenCalledWith(
|
||||
{ page: 1, size: 100 },
|
||||
{ query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] },
|
||||
);
|
||||
});
|
||||
|
||||
it('should consider page and size parameters', async () => {
|
||||
await sut.searchSmart(authStub.user1, { query: 'test', page: 2, size: 50 });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ modelName: expect.any(String) }),
|
||||
);
|
||||
expect(mocks.search.searchSmart).toHaveBeenCalledWith(
|
||||
{ page: 2, size: 50 },
|
||||
expect.objectContaining({ query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use clip model specified in config', async () => {
|
||||
mocks.systemMetadata.get.mockResolvedValue({
|
||||
machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } },
|
||||
});
|
||||
|
||||
await sut.searchSmart(authStub.user1, { query: 'test' });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ modelName: 'ViT-B-16-SigLIP__webli' }),
|
||||
);
|
||||
});
|
||||
|
||||
it('should use language specified in request', async () => {
|
||||
await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' });
|
||||
|
||||
expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
|
||||
[expect.any(String)],
|
||||
'test',
|
||||
expect.objectContaining({ language: 'de' }),
|
||||
);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
|
|
|||
|
|
@ -78,12 +78,10 @@ export class SearchService extends BaseService {
|
|||
}
|
||||
|
||||
const userIds = await this.getUserIdsToSearch(auth);
|
||||
|
||||
const embedding = await this.machineLearningRepository.encodeText(
|
||||
machineLearning.urls,
|
||||
dto.query,
|
||||
machineLearning.clip,
|
||||
);
|
||||
const embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, {
|
||||
modelName: machineLearning.clip.modelName,
|
||||
language: dto.language,
|
||||
});
|
||||
const page = dto.page ?? 1;
|
||||
const size = dto.size || 100;
|
||||
const { hasNextPage, items } = await this.searchRepository.searchSmart(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue