feat(ml): better multilingual search with nllb models (#13567)

This commit is contained in:
Mert 2025-03-31 11:06:57 -04:00 committed by GitHub
parent 838a8dd9a6
commit 6789c2ac19
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
16 changed files with 301 additions and 18 deletions

View file

@ -191,6 +191,11 @@ export class SmartSearchDto extends BaseSearchDto {
@IsNotEmpty()
query!: string;
/**
 * Optional language of the search query. When supplied it must be a non-empty string.
 * NOTE(review): presumably a language/locale code such as 'de' — confirm the accepted
 * format against whatever consumes this value.
 */
@IsString()
@IsNotEmpty()
@Optional()
language?: string;
@IsInt()
@Min(1)
@Type(() => Number)

View file

@ -53,6 +53,7 @@ export interface Face {
/** Facial recognition result: the detected faces keyed by the facial recognition task, intersected with `VisualResponse`. */
export type FacialRecognitionResponse = { [ModelTask.FACIAL_RECOGNITION]: Face[] } & VisualResponse;
/** The detected faces exposed under a plain `faces` key, intersected with `VisualResponse`. */
export type DetectedFaces = { faces: Face[] } & VisualResponse;
/** Any one of the three request types: CLIP visual, CLIP textual, or facial recognition. */
export type MachineLearningRequest = ClipVisualRequest | ClipTextualRequest | FacialRecognitionRequest;
/** `ModelOptions` extended with an optional `language` used when encoding text. */
export type TextEncodingOptions = ModelOptions & { language?: string };
@Injectable()
export class MachineLearningRepository {
@ -170,8 +171,8 @@ export class MachineLearningRepository {
return response[ModelTask.SEARCH];
}
async encodeText(urls: string[], text: string, { modelName }: CLIPConfig) {
const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: { modelName } } };
/**
 * Encodes `text` with the textual search model and returns the search-task entry of the
 * prediction response.
 *
 * @param urls forwarded unchanged to `predict`
 *   (NOTE(review): presumably the machine learning server URLs — confirm against `predict`).
 * @param text the text to encode.
 * @param options `modelName` selects the textual model; `language` is forwarded inside the
 *   model's `options` object and may be undefined.
 * @returns the value stored under `ModelTask.SEARCH` in the response.
 */
async encodeText(urls: string[], text: string, { language, modelName }: TextEncodingOptions) {
  // Describe the textual model first, then nest it under the search task.
  const textualModel = { modelName, options: { language } };
  const request = { [ModelTask.SEARCH]: { [ModelType.TEXTUAL]: textualModel } };

  const { [ModelTask.SEARCH]: embedding } = await this.predict<ClipTextualResponse>(urls, { text }, request);
  return embedding;
}

View file

@ -1,3 +1,4 @@
import { BadRequestException } from '@nestjs/common';
import { mapAsset } from 'src/dtos/asset-response.dto';
import { SearchSuggestionType } from 'src/dtos/search.dto';
import { SearchService } from 'src/services/search.service';
@ -15,6 +16,7 @@ describe(SearchService.name, () => {
beforeEach(() => {
  // Create the service under test together with its mocks before every test.
  const testService = newTestService(SearchService);
  sut = testService.sut;
  mocks = testService.mocks;

  // Partner lookups resolve to an empty list unless a test overrides it.
  mocks.partner.getAll.mockResolvedValue([]);
});
it('should work', () => {
@ -155,4 +157,83 @@ describe(SearchService.name, () => {
expect(mocks.search.getCameraModels).toHaveBeenCalledWith([authStub.user1.user.id], expect.anything());
});
});
describe('searchSmart', () => {
  // Asserts that the text encoder was called with a single URL string, the 'test' query,
  // and an options object containing at least `expectedOptions`.
  const expectEncodeTextCalledWith = (expectedOptions: Record<string, unknown>) => {
    expect(mocks.machineLearning.encodeText).toHaveBeenCalledWith(
      [expect.any(String)],
      'test',
      expect.objectContaining(expectedOptions),
    );
  };

  // Asserts that a smart search for 'test' rejects with the "not enabled" BadRequestException.
  const expectSmartSearchRejected = async () => {
    const attempt = sut.searchSmart(authStub.user1, { query: 'test' });
    await expect(attempt).rejects.toThrowError(new BadRequestException('Smart search is not enabled'));
  };

  beforeEach(() => {
    // Defaults: the encoder yields a fixed embedding string and the repository returns an empty page.
    mocks.machineLearning.encodeText.mockResolvedValue('[1, 2, 3]');
    mocks.search.searchSmart.mockResolvedValue({ hasNextPage: false, items: [] });
  });

  it('should raise a BadRequestException if machine learning is disabled', async () => {
    const config = { machineLearning: { enabled: false } };
    mocks.systemMetadata.get.mockResolvedValue(config);

    await expectSmartSearchRejected();
  });

  it('should raise a BadRequestException if smart search is disabled', async () => {
    const config = { machineLearning: { clip: { enabled: false } } };
    mocks.systemMetadata.get.mockResolvedValue(config);

    await expectSmartSearchRejected();
  });

  it('should work', async () => {
    await sut.searchSmart(authStub.user1, { query: 'test' });

    expectEncodeTextCalledWith({ modelName: expect.any(String) });

    // With no paging parameters the repository receives page 1 and size 100.
    const pagination = { page: 1, size: 100 };
    const searchOptions = { query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] };
    expect(mocks.search.searchSmart).toHaveBeenCalledWith(pagination, searchOptions);
  });

  it('should consider page and size parameters', async () => {
    await sut.searchSmart(authStub.user1, { query: 'test', page: 2, size: 50 });

    expectEncodeTextCalledWith({ modelName: expect.any(String) });

    const pagination = { page: 2, size: 50 };
    const searchOptions = { query: 'test', embedding: '[1, 2, 3]', userIds: [authStub.user1.user.id] };
    expect(mocks.search.searchSmart).toHaveBeenCalledWith(pagination, expect.objectContaining(searchOptions));
  });

  it('should use clip model specified in config', async () => {
    const config = { machineLearning: { clip: { modelName: 'ViT-B-16-SigLIP__webli' } } };
    mocks.systemMetadata.get.mockResolvedValue(config);

    await sut.searchSmart(authStub.user1, { query: 'test' });

    expectEncodeTextCalledWith({ modelName: 'ViT-B-16-SigLIP__webli' });
  });

  it('should use language specified in request', async () => {
    await sut.searchSmart(authStub.user1, { query: 'test', language: 'de' });

    expectEncodeTextCalledWith({ language: 'de' });
  });
});
});

View file

@ -78,12 +78,10 @@ export class SearchService extends BaseService {
}
const userIds = await this.getUserIdsToSearch(auth);
const embedding = await this.machineLearningRepository.encodeText(
machineLearning.urls,
dto.query,
machineLearning.clip,
);
const embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, {
modelName: machineLearning.clip.modelName,
language: dto.language,
});
const page = dto.page ?? 1;
const size = dto.size || 100;
const { hasNextPage, items } = await this.searchRepository.searchSmart(