feat: view similar photos (#21108)

* Enable filteing by example

* Drop `@GenerateSql` for `getEmbedding`?

* Improve error message

* PR Feedback

* Sort en.json

* Add SQL

* Fix lint

* Drop test that is no longer valid

* Fix i18n file sorting

* Fix TS error

* Add a `requireAccess` before pulling the embedding

* Fix decorators

* Run `make open-api`

---------

Co-authored-by: Alex <alex.tran1502@gmail.com>
This commit is contained in:
Arthur Normand 2025-09-04 10:22:09 -04:00 committed by GitHub
parent bf6211776f
commit 37a79292c0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 105 additions and 29 deletions

View file

@ -128,12 +128,6 @@ describe(SearchController.name, () => {
await request(ctx.getHttpServer()).post('/search/smart');
expect(ctx.authenticate).toHaveBeenCalled();
});
it('should require a query', async () => {
const { status, body } = await request(ctx.getHttpServer()).post('/search/smart').send({});
expect(status).toBe(400);
expect(body).toEqual(errorDto.badRequest(['query should not be empty', 'query must be a string']));
});
});
describe('GET /search/explore', () => {

View file

@ -199,7 +199,12 @@ export class StatisticsSearchDto extends BaseSearchDto {
export class SmartSearchDto extends BaseSearchWithResultsDto {
@IsString()
@IsNotEmpty()
query!: string;
@Optional()
query?: string;
@ValidateUUID({ optional: true })
@Optional()
queryAssetId?: string;
@IsString()
@IsNotEmpty()

View file

@ -123,6 +123,14 @@ offset
$8
commit
-- SearchRepository.getEmbedding
select
*
from
"smart_search"
where
"assetId" = $1
-- SearchRepository.searchFaces
begin
set

View file

@ -293,6 +293,13 @@ export class SearchRepository {
});
}
@GenerateSql({
params: [DummyValue.UUID],
})
async getEmbedding(assetId: string) {
return this.db.selectFrom('smart_search').selectAll().where('assetId', '=', assetId).executeTakeFirst();
}
@GenerateSql({
params: [
{

View file

@ -18,7 +18,7 @@ import {
SmartSearchDto,
StatisticsSearchDto,
} from 'src/dtos/search.dto';
import { AssetOrder, AssetVisibility } from 'src/enum';
import { AssetOrder, AssetVisibility, Permission } from 'src/enum';
import { BaseService } from 'src/services/base.service';
import { requireElevatedPermission } from 'src/utils/access';
import { getMyPartnerIds } from 'src/utils/asset.util';
@ -113,14 +113,27 @@ export class SearchService extends BaseService {
}
const userIds = this.getUserIdsToSearch(auth);
const key = machineLearning.clip.modelName + dto.query + dto.language;
let embedding = this.embeddingCache.get(key);
if (!embedding) {
embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, {
modelName: machineLearning.clip.modelName,
language: dto.language,
});
this.embeddingCache.set(key, embedding);
let embedding;
if (dto.query) {
const key = machineLearning.clip.modelName + dto.query + dto.language;
embedding = this.embeddingCache.get(key);
if (!embedding) {
embedding = await this.machineLearningRepository.encodeText(machineLearning.urls, dto.query, {
modelName: machineLearning.clip.modelName,
language: dto.language,
});
this.embeddingCache.set(key, embedding);
}
} else if (dto.queryAssetId) {
await this.requireAccess({ auth, permission: Permission.AssetRead, ids: [dto.queryAssetId] });
const getEmbeddingResponse = await this.searchRepository.getEmbedding(dto.queryAssetId);
const assetEmbedding = getEmbeddingResponse?.embedding;
if (!assetEmbedding) {
throw new BadRequestException(`Asset ${dto.queryAssetId} has no embedding`);
}
embedding = assetEmbedding;
} else {
throw new BadRequestException('Either `query` or `queryAssetId` must be set');
}
const page = dto.page ?? 1;
const size = dto.size || 100;