feat(server): CLIP search integration (#1939)

This commit is contained in:
Alex 2023-03-18 08:44:42 -05:00 committed by GitHub
parent 0d436db3ea
commit f56eaae019
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
46 changed files with 673 additions and 773 deletions

View file

@ -3,6 +3,7 @@ import { AlbumEntity } from '@app/infra/db/entities';
export const IAlbumRepository = 'IAlbumRepository';
export interface IAlbumRepository {
getByIds(ids: string[]): Promise<AlbumEntity[]>;
deleteAll(userId: string): Promise<void>;
getAll(): Promise<AlbumEntity[]>;
save(album: Partial<AlbumEntity>): Promise<AlbumEntity>;

View file

@ -11,7 +11,10 @@ export class AssetCore {
async save(asset: Partial<AssetEntity>) {
const _asset = await this.assetRepository.save(asset);
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { asset: _asset } });
await this.jobRepository.queue({
name: JobName.SEARCH_INDEX_ASSET,
data: { ids: [_asset.id] },
});
return _asset;
}

View file

@ -7,6 +7,7 @@ export interface AssetSearchOptions {
export const IAssetRepository = 'IAssetRepository';
export interface IAssetRepository {
getByIds(ids: string[]): Promise<AssetEntity[]>;
deleteAll(ownerId: string): Promise<void>;
getAll(options?: AssetSearchOptions): Promise<AssetEntity[]>;
save(asset: Partial<AssetEntity>): Promise<AssetEntity>;

View file

@ -54,7 +54,7 @@ describe(AssetService.name, () => {
expect(assetMock.save).toHaveBeenCalledWith(assetEntityStub.image);
expect(jobMock.queue).toHaveBeenCalledWith({
name: JobName.SEARCH_INDEX_ASSET,
data: { asset: assetEntityStub.image },
data: { ids: [assetEntityStub.image.id] },
});
});
});

View file

@ -29,4 +29,5 @@ export enum JobName {
SEARCH_INDEX_ALBUM = 'search-index-album',
SEARCH_REMOVE_ALBUM = 'search-remove-album',
SEARCH_REMOVE_ASSET = 'search-remove-asset',
ENCODE_CLIP = 'clip-encode',
}

View file

@ -8,15 +8,15 @@ export interface IAssetJob {
asset: AssetEntity;
}
export interface IBulkEntityJob {
ids: string[];
}
export interface IAssetUploadedJob {
asset: AssetEntity;
fileName: string;
}
export interface IDeleteJob {
id: string;
}
export interface IDeleteFilesJob {
files: Array<string | null | undefined>;
}

View file

@ -1,10 +1,9 @@
import { JobName, QueueName } from './job.constants';
import {
IAlbumJob,
IAssetJob,
IAssetUploadedJob,
IBulkEntityJob,
IDeleteFilesJob,
IDeleteJob,
IReverseGeocodingJob,
IUserDeletionJob,
} from './job.interface';
@ -31,13 +30,14 @@ export type JobItem =
| { name: JobName.EXTRACT_VIDEO_METADATA; data: IAssetUploadedJob }
| { name: JobName.OBJECT_DETECTION; data: IAssetJob }
| { name: JobName.IMAGE_TAGGING; data: IAssetJob }
| { name: JobName.ENCODE_CLIP; data: IAssetJob }
| { name: JobName.DELETE_FILES; data: IDeleteFilesJob }
| { name: JobName.SEARCH_INDEX_ASSETS }
| { name: JobName.SEARCH_INDEX_ASSET; data: IAssetJob }
| { name: JobName.SEARCH_INDEX_ASSET; data: IBulkEntityJob }
| { name: JobName.SEARCH_INDEX_ALBUMS }
| { name: JobName.SEARCH_INDEX_ALBUM; data: IAlbumJob }
| { name: JobName.SEARCH_REMOVE_ASSET; data: IDeleteJob }
| { name: JobName.SEARCH_REMOVE_ALBUM; data: IDeleteJob };
| { name: JobName.SEARCH_INDEX_ALBUM; data: IBulkEntityJob }
| { name: JobName.SEARCH_REMOVE_ASSET; data: IBulkEntityJob }
| { name: JobName.SEARCH_REMOVE_ALBUM; data: IBulkEntityJob };
export const IJobRepository = 'IJobRepository';

View file

@ -54,6 +54,7 @@ export class MediaService {
await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } });
await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } });
await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } });
this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset));
}
@ -72,6 +73,7 @@ export class MediaService {
await this.jobRepository.queue({ name: JobName.GENERATE_WEBP_THUMBNAIL, data: { asset } });
await this.jobRepository.queue({ name: JobName.IMAGE_TAGGING, data: { asset } });
await this.jobRepository.queue({ name: JobName.OBJECT_DETECTION, data: { asset } });
await this.jobRepository.queue({ name: JobName.ENCODE_CLIP, data: { asset } });
this.communicationRepository.send(CommunicationEvent.UPLOAD_SUCCESS, asset.ownerId, mapAsset(asset));
} catch (error: any) {

View file

@ -4,11 +4,21 @@ import { IsArray, IsBoolean, IsEnum, IsNotEmpty, IsOptional, IsString } from 'cl
import { toBoolean } from '../../../../../apps/immich/src/utils/transform.util';
export class SearchDto {
@IsString()
@IsNotEmpty()
@IsOptional()
q?: string;
@IsString()
@IsNotEmpty()
@IsOptional()
query?: string;
@IsBoolean()
@IsOptional()
@Transform(toBoolean)
clip?: boolean;
@IsEnum(AssetType)
@IsOptional()
type?: AssetType;

View file

@ -5,6 +5,11 @@ export enum SearchCollection {
ALBUMS = 'albums',
}
export enum SearchStrategy {
CLIP = 'CLIP',
TEXT = 'TEXT',
}
export interface SearchFilter {
id?: string;
userId: string;
@ -19,6 +24,7 @@ export interface SearchFilter {
tags?: string[];
recent?: boolean;
motion?: boolean;
debug?: boolean;
}
export interface SearchResult<T> {
@ -57,16 +63,15 @@ export interface ISearchRepository {
setup(): Promise<void>;
checkMigrationStatus(): Promise<SearchCollectionIndexStatus>;
index(collection: SearchCollection.ASSETS, item: AssetEntity): Promise<void>;
index(collection: SearchCollection.ALBUMS, item: AlbumEntity): Promise<void>;
importAlbums(items: AlbumEntity[], done: boolean): Promise<void>;
importAssets(items: AssetEntity[], done: boolean): Promise<void>;
delete(collection: SearchCollection, id: string): Promise<void>;
deleteAlbums(ids: string[]): Promise<void>;
deleteAssets(ids: string[]): Promise<void>;
import(collection: SearchCollection.ASSETS, items: AssetEntity[], done: boolean): Promise<void>;
import(collection: SearchCollection.ALBUMS, items: AlbumEntity[], done: boolean): Promise<void>;
search(collection: SearchCollection.ASSETS, query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
search(collection: SearchCollection.ALBUMS, query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>;
searchAlbums(query: string, filters: SearchFilter): Promise<SearchResult<AlbumEntity>>;
searchAssets(query: string, filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
vectorSearch(query: number[], filters: SearchFilter): Promise<SearchResult<AssetEntity>>;
explore(userId: string): Promise<SearchExploreItem<AssetEntity>[]>;
}

View file

@ -4,25 +4,32 @@ import { plainToInstance } from 'class-transformer';
import {
albumStub,
assetEntityStub,
asyncTick,
authStub,
newAlbumRepositoryMock,
newAssetRepositoryMock,
newJobRepositoryMock,
newMachineLearningRepositoryMock,
newSearchRepositoryMock,
searchStub,
} from '../../test';
import { IAlbumRepository } from '../album/album.repository';
import { IAssetRepository } from '../asset/asset.repository';
import { JobName } from '../job';
import { IJobRepository } from '../job/job.repository';
import { IMachineLearningRepository } from '../smart-info';
import { SearchDto } from './dto';
import { ISearchRepository } from './search.repository';
import { SearchService } from './search.service';
jest.useFakeTimers();
describe(SearchService.name, () => {
let sut: SearchService;
let albumMock: jest.Mocked<IAlbumRepository>;
let assetMock: jest.Mocked<IAssetRepository>;
let jobMock: jest.Mocked<IJobRepository>;
let machineMock: jest.Mocked<IMachineLearningRepository>;
let searchMock: jest.Mocked<ISearchRepository>;
let configMock: jest.Mocked<ConfigService>;
@ -30,10 +37,15 @@ describe(SearchService.name, () => {
albumMock = newAlbumRepositoryMock();
assetMock = newAssetRepositoryMock();
jobMock = newJobRepositoryMock();
machineMock = newMachineLearningRepositoryMock();
searchMock = newSearchRepositoryMock();
configMock = { get: jest.fn() } as unknown as jest.Mocked<ConfigService>;
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
});
afterEach(() => {
sut.teardown();
});
it('should work', () => {
@ -69,7 +81,7 @@ describe(SearchService.name, () => {
it('should be disabled via an env variable', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
expect(sut.isEnabled()).toBe(false);
});
@ -82,7 +94,7 @@ describe(SearchService.name, () => {
it('should return the config when search is disabled', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
expect(sut.getConfig()).toEqual({ enabled: false });
});
@ -91,13 +103,15 @@ describe(SearchService.name, () => {
describe(`bootstrap`, () => {
it('should skip when search is disabled', async () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
await sut.bootstrap();
expect(searchMock.setup).not.toHaveBeenCalled();
expect(searchMock.checkMigrationStatus).not.toHaveBeenCalled();
expect(jobMock.queue).not.toHaveBeenCalled();
sut.teardown();
});
it('should skip schema migration if not needed', async () => {
@ -123,21 +137,18 @@ describe(SearchService.name, () => {
describe('search', () => {
it('should throw an error is search is disabled', async () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
await expect(sut.search(authStub.admin, {})).rejects.toBeInstanceOf(BadRequestException);
expect(searchMock.search).not.toHaveBeenCalled();
expect(searchMock.searchAlbums).not.toHaveBeenCalled();
expect(searchMock.searchAssets).not.toHaveBeenCalled();
});
it('should search assets and albums', async () => {
searchMock.search.mockResolvedValue({
total: 0,
count: 0,
page: 1,
items: [],
facets: [],
});
searchMock.searchAssets.mockResolvedValue(searchStub.emptyResults);
searchMock.searchAlbums.mockResolvedValue(searchStub.emptyResults);
searchMock.vectorSearch.mockResolvedValue(searchStub.emptyResults);
await expect(sut.search(authStub.admin, {})).resolves.toEqual({
albums: {
@ -156,162 +167,158 @@ describe(SearchService.name, () => {
},
});
expect(searchMock.search.mock.calls).toEqual([
['assets', '*', { userId: authStub.admin.id }],
['albums', '*', { userId: authStub.admin.id }],
]);
// expect(searchMock.searchAssets).toHaveBeenCalledWith('*', { userId: authStub.admin.id });
expect(searchMock.searchAlbums).toHaveBeenCalledWith('*', { userId: authStub.admin.id });
});
});
describe('handleIndexAssets', () => {
it('should skip if search is disabled', async () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleIndexAssets();
expect(searchMock.import).not.toHaveBeenCalled();
});
it('should index all the assets', async () => {
assetMock.getAll.mockResolvedValue([]);
assetMock.getAll.mockResolvedValue([assetEntityStub.image]);
await sut.handleIndexAssets();
expect(searchMock.import).toHaveBeenCalledWith('assets', [], true);
expect(searchMock.importAssets).toHaveBeenCalledWith([assetEntityStub.image], true);
});
it('should log an error', async () => {
assetMock.getAll.mockResolvedValue([]);
searchMock.import.mockRejectedValue(new Error('import failed'));
assetMock.getAll.mockResolvedValue([assetEntityStub.image]);
searchMock.importAssets.mockRejectedValue(new Error('import failed'));
await sut.handleIndexAssets();
expect(searchMock.importAssets).toHaveBeenCalled();
});
it('should skip if search is disabled', async () => {
configMock.get.mockReturnValue('false');
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
await sut.handleIndexAssets();
expect(searchMock.importAssets).not.toHaveBeenCalled();
expect(searchMock.importAlbums).not.toHaveBeenCalled();
});
});
describe('handleIndexAsset', () => {
it('should skip if search is disabled', async () => {
it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleIndexAsset({ asset: assetEntityStub.image });
expect(searchMock.index).not.toHaveBeenCalled();
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleIndexAsset({ ids: [assetEntityStub.image.id] });
});
it('should index the asset', async () => {
await sut.handleIndexAsset({ asset: assetEntityStub.image });
expect(searchMock.index).toHaveBeenCalledWith('assets', assetEntityStub.image);
});
it('should log an error', async () => {
searchMock.index.mockRejectedValue(new Error('index failed'));
await sut.handleIndexAsset({ asset: assetEntityStub.image });
expect(searchMock.index).toHaveBeenCalled();
it('should index the asset', () => {
sut.handleIndexAsset({ ids: [assetEntityStub.image.id] });
});
});
describe('handleIndexAlbums', () => {
it('should skip if search is disabled', async () => {
it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleIndexAlbums();
expect(searchMock.import).not.toHaveBeenCalled();
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleIndexAlbums();
});
it('should index all the albums', async () => {
albumMock.getAll.mockResolvedValue([]);
albumMock.getAll.mockResolvedValue([albumStub.empty]);
await sut.handleIndexAlbums();
expect(searchMock.import).toHaveBeenCalledWith('albums', [], true);
expect(searchMock.importAlbums).toHaveBeenCalledWith([albumStub.empty], true);
});
it('should log an error', async () => {
albumMock.getAll.mockResolvedValue([]);
searchMock.import.mockRejectedValue(new Error('import failed'));
albumMock.getAll.mockResolvedValue([albumStub.empty]);
searchMock.importAlbums.mockRejectedValue(new Error('import failed'));
await sut.handleIndexAlbums();
expect(searchMock.importAlbums).toHaveBeenCalled();
});
});
describe('handleIndexAlbum', () => {
it('should skip if search is disabled', async () => {
it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleIndexAlbum({ album: albumStub.empty });
expect(searchMock.index).not.toHaveBeenCalled();
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleIndexAlbum({ ids: [albumStub.empty.id] });
});
it('should index the album', async () => {
await sut.handleIndexAlbum({ album: albumStub.empty });
expect(searchMock.index).toHaveBeenCalledWith('albums', albumStub.empty);
});
it('should log an error', async () => {
searchMock.index.mockRejectedValue(new Error('index failed'));
await sut.handleIndexAlbum({ album: albumStub.empty });
expect(searchMock.index).toHaveBeenCalled();
it('should index the album', () => {
sut.handleIndexAlbum({ ids: [albumStub.empty.id] });
});
});
describe('handleRemoveAlbum', () => {
it('should skip if search is disabled', async () => {
it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleRemoveAlbum({ id: 'album1' });
expect(searchMock.delete).not.toHaveBeenCalled();
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleRemoveAlbum({ ids: ['album1'] });
});
it('should remove the album', async () => {
await sut.handleRemoveAlbum({ id: 'album1' });
expect(searchMock.delete).toHaveBeenCalledWith('albums', 'album1');
});
it('should log an error', async () => {
searchMock.delete.mockRejectedValue(new Error('remove failed'));
await sut.handleRemoveAlbum({ id: 'album1' });
expect(searchMock.delete).toHaveBeenCalled();
it('should remove the album', () => {
sut.handleRemoveAlbum({ ids: ['album1'] });
});
});
describe('handleRemoveAsset', () => {
it('should skip if search is disabled', async () => {
it('should skip if search is disabled', () => {
configMock.get.mockReturnValue('false');
sut = new SearchService(albumMock, assetMock, jobMock, searchMock, configMock);
await sut.handleRemoveAsset({ id: 'asset1`' });
expect(searchMock.delete).not.toHaveBeenCalled();
const sut = new SearchService(albumMock, assetMock, jobMock, machineMock, searchMock, configMock);
sut.handleRemoveAsset({ ids: ['asset1'] });
});
it('should remove the asset', async () => {
await sut.handleRemoveAsset({ id: 'asset1' });
it('should remove the asset', () => {
sut.handleRemoveAsset({ ids: ['asset1'] });
});
});
expect(searchMock.delete).toHaveBeenCalledWith('assets', 'asset1');
describe('flush', () => {
it('should flush queued album updates', async () => {
albumMock.getByIds.mockResolvedValue([albumStub.empty]);
sut.handleIndexAlbum({ ids: ['album1'] });
jest.runOnlyPendingTimers();
await asyncTick(4);
expect(albumMock.getByIds).toHaveBeenCalledWith(['album1']);
expect(searchMock.importAlbums).toHaveBeenCalledWith([albumStub.empty], false);
});
it('should log an error', async () => {
searchMock.delete.mockRejectedValue(new Error('remove failed'));
it('should flush queued album deletes', async () => {
sut.handleRemoveAlbum({ ids: ['album1'] });
await sut.handleRemoveAsset({ id: 'asset1' });
jest.runOnlyPendingTimers();
expect(searchMock.delete).toHaveBeenCalled();
await asyncTick(4);
expect(searchMock.deleteAlbums).toHaveBeenCalledWith(['album1']);
});
it('should flush queued asset updates', async () => {
assetMock.getByIds.mockResolvedValue([assetEntityStub.image]);
sut.handleIndexAsset({ ids: ['asset1'] });
jest.runOnlyPendingTimers();
await asyncTick(4);
expect(assetMock.getByIds).toHaveBeenCalledWith(['asset1']);
expect(searchMock.importAssets).toHaveBeenCalledWith([assetEntityStub.image], false);
});
it('should flush queued asset deletes', async () => {
sut.handleRemoveAsset({ ids: ['asset1'] });
jest.runOnlyPendingTimers();
await asyncTick(4);
expect(searchMock.deleteAssets).toHaveBeenCalledWith(['asset1']);
});
});
});

View file

@ -1,27 +1,64 @@
import { AssetEntity } from '@app/infra/db/entities';
import { MACHINE_LEARNING_ENABLED } from '@app/common';
import { AlbumEntity, AssetEntity } from '@app/infra/db/entities';
import { BadRequestException, Inject, Injectable, Logger } from '@nestjs/common';
import { ConfigService } from '@nestjs/config';
import { mapAlbum } from '../album';
import { IAlbumRepository } from '../album/album.repository';
import { mapAsset } from '../asset';
import { IAssetRepository } from '../asset/asset.repository';
import { AuthUserDto } from '../auth';
import { IAlbumJob, IAssetJob, IDeleteJob, IJobRepository, JobName } from '../job';
import { IBulkEntityJob, IJobRepository, JobName } from '../job';
import { IMachineLearningRepository } from '../smart-info';
import { SearchDto } from './dto';
import { SearchConfigResponseDto, SearchResponseDto } from './response-dto';
import { ISearchRepository, SearchCollection, SearchExploreItem } from './search.repository';
import {
ISearchRepository,
SearchCollection,
SearchExploreItem,
SearchResult,
SearchStrategy,
} from './search.repository';
interface SyncQueue {
upsert: Set<string>;
delete: Set<string>;
}
@Injectable()
export class SearchService {
private logger = new Logger(SearchService.name);
private enabled: boolean;
private timer: NodeJS.Timer | null = null;
private albumQueue: SyncQueue = {
upsert: new Set(),
delete: new Set(),
};
private assetQueue: SyncQueue = {
upsert: new Set(),
delete: new Set(),
};
constructor(
@Inject(IAlbumRepository) private albumRepository: IAlbumRepository,
@Inject(IAssetRepository) private assetRepository: IAssetRepository,
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
@Inject(ISearchRepository) private searchRepository: ISearchRepository,
configService: ConfigService,
) {
this.enabled = configService.get('TYPESENSE_ENABLED') !== 'false';
if (this.enabled) {
this.timer = setInterval(() => this.flush(), 5_000);
}
}
teardown() {
if (this.timer) {
clearInterval(this.timer);
this.timer = null;
}
}
isEnabled() {
@ -61,103 +98,131 @@ export class SearchService {
async search(authUser: AuthUserDto, dto: SearchDto): Promise<SearchResponseDto> {
this.assertEnabled();
const query = dto.query || '*';
const query = dto.q || dto.query || '*';
const strategy = dto.clip ? SearchStrategy.CLIP : SearchStrategy.TEXT;
const filters = { userId: authUser.id, ...dto };
let assets: SearchResult<AssetEntity>;
switch (strategy) {
case SearchStrategy.TEXT:
assets = await this.searchRepository.searchAssets(query, filters);
break;
case SearchStrategy.CLIP:
default:
if (!MACHINE_LEARNING_ENABLED) {
throw new BadRequestException('Machine Learning is disabled');
}
const clip = await this.machineLearning.encodeText(query);
assets = await this.searchRepository.vectorSearch(clip, filters);
}
const albums = await this.searchRepository.searchAlbums(query, filters);
return {
assets: (await this.searchRepository.search(SearchCollection.ASSETS, query, {
userId: authUser.id,
...dto,
})) as any,
albums: (await this.searchRepository.search(SearchCollection.ALBUMS, query, {
userId: authUser.id,
...dto,
})) as any,
albums: { ...albums, items: albums.items.map(mapAlbum) },
assets: { ...assets, items: assets.items.map(mapAsset) },
};
}
async handleIndexAssets() {
if (!this.enabled) {
return;
}
try {
this.logger.debug(`Running indexAssets`);
// TODO: do this in batches based on searchIndexVersion
const assets = await this.assetRepository.getAll({ isVisible: true });
this.logger.log(`Indexing ${assets.length} assets`);
await this.searchRepository.import(SearchCollection.ASSETS, assets, true);
this.logger.debug('Finished re-indexing all assets');
} catch (error: any) {
this.logger.error(`Unable to index all assets`, error?.stack);
}
}
async handleIndexAsset(data: IAssetJob) {
if (!this.enabled) {
return;
}
const { asset } = data;
if (!asset.isVisible) {
return;
}
try {
await this.searchRepository.index(SearchCollection.ASSETS, asset);
} catch (error: any) {
this.logger.error(`Unable to index asset: ${asset.id}`, error?.stack);
}
}
async handleIndexAlbums() {
if (!this.enabled) {
return;
}
try {
const albums = await this.albumRepository.getAll();
const albums = this.patchAlbums(await this.albumRepository.getAll());
this.logger.log(`Indexing ${albums.length} albums`);
await this.searchRepository.import(SearchCollection.ALBUMS, albums, true);
this.logger.debug('Finished re-indexing all albums');
await this.searchRepository.importAlbums(albums, true);
} catch (error: any) {
this.logger.error(`Unable to index all albums`, error?.stack);
}
}
async handleIndexAlbum(data: IAlbumJob) {
async handleIndexAssets() {
if (!this.enabled) {
return;
}
const { album } = data;
try {
await this.searchRepository.index(SearchCollection.ALBUMS, album);
// TODO: do this in batches based on searchIndexVersion
const assets = this.patchAssets(await this.assetRepository.getAll({ isVisible: true }));
this.logger.log(`Indexing ${assets.length} assets`);
await this.searchRepository.importAssets(assets, true);
this.logger.debug('Finished re-indexing all assets');
} catch (error: any) {
this.logger.error(`Unable to index album: ${album.id}`, error?.stack);
this.logger.error(`Unable to index all assets`, error?.stack);
}
}
async handleRemoveAlbum(data: IDeleteJob) {
await this.handleRemove(SearchCollection.ALBUMS, data);
}
async handleRemoveAsset(data: IDeleteJob) {
await this.handleRemove(SearchCollection.ASSETS, data);
}
private async handleRemove(collection: SearchCollection, data: IDeleteJob) {
handleIndexAlbum({ ids }: IBulkEntityJob) {
if (!this.enabled) {
return;
}
const { id } = data;
for (const id of ids) {
this.albumQueue.upsert.add(id);
}
}
try {
await this.searchRepository.delete(collection, id);
} catch (error: any) {
this.logger.error(`Unable to remove ${collection}: ${id}`, error?.stack);
handleIndexAsset({ ids }: IBulkEntityJob) {
if (!this.enabled) {
return;
}
for (const id of ids) {
this.assetQueue.upsert.add(id);
}
}
handleRemoveAlbum({ ids }: IBulkEntityJob) {
if (!this.enabled) {
return;
}
for (const id of ids) {
this.albumQueue.delete.add(id);
}
}
handleRemoveAsset({ ids }: IBulkEntityJob) {
if (!this.enabled) {
return;
}
for (const id of ids) {
this.assetQueue.delete.add(id);
}
}
private async flush() {
if (this.albumQueue.upsert.size > 0) {
const ids = [...this.albumQueue.upsert.keys()];
const items = await this.idsToAlbums(ids);
this.logger.debug(`Flushing ${items.length} album upserts`);
await this.searchRepository.importAlbums(items, false);
this.albumQueue.upsert.clear();
}
if (this.albumQueue.delete.size > 0) {
const ids = [...this.albumQueue.delete.keys()];
this.logger.debug(`Flushing ${ids.length} album deletes`);
await this.searchRepository.deleteAlbums(ids);
this.albumQueue.delete.clear();
}
if (this.assetQueue.upsert.size > 0) {
const ids = [...this.assetQueue.upsert.keys()];
const items = await this.idsToAssets(ids);
this.logger.debug(`Flushing ${items.length} asset upserts`);
await this.searchRepository.importAssets(items, false);
this.assetQueue.upsert.clear();
}
if (this.assetQueue.delete.size > 0) {
const ids = [...this.assetQueue.delete.keys()];
this.logger.debug(`Flushing ${ids.length} asset deletes`);
await this.searchRepository.deleteAssets(ids);
this.assetQueue.delete.clear();
}
}
@ -166,4 +231,22 @@ export class SearchService {
throw new BadRequestException('Search is disabled');
}
}
private async idsToAlbums(ids: string[]): Promise<AlbumEntity[]> {
const entities = await this.albumRepository.getByIds(ids);
return this.patchAlbums(entities);
}
private async idsToAssets(ids: string[]): Promise<AssetEntity[]> {
const entities = await this.assetRepository.getByIds(ids);
return this.patchAssets(entities.filter((entity) => entity.isVisible));
}
private patchAssets(assets: AssetEntity[]): AssetEntity[] {
return assets;
}
private patchAlbums(albums: AlbumEntity[]): AlbumEntity[] {
return albums.map((entity) => ({ ...entity, assets: [] }));
}
}

View file

@ -7,4 +7,6 @@ export interface MachineLearningInput {
export interface IMachineLearningRepository {
tagImage(input: MachineLearningInput): Promise<string[]>;
detectObjects(input: MachineLearningInput): Promise<string[]>;
encodeImage(input: MachineLearningInput): Promise<number[]>;
encodeText(input: string): Promise<number[]>;
}

View file

@ -1,5 +1,6 @@
import { AssetEntity } from '@app/infra/db/entities';
import { newMachineLearningRepositoryMock, newSmartInfoRepositoryMock } from '../../test';
import { newJobRepositoryMock, newMachineLearningRepositoryMock, newSmartInfoRepositoryMock } from '../../test';
import { IJobRepository } from '../job';
import { IMachineLearningRepository } from './machine-learning.interface';
import { ISmartInfoRepository } from './smart-info.repository';
import { SmartInfoService } from './smart-info.service';
@ -11,13 +12,15 @@ const asset = {
describe(SmartInfoService.name, () => {
let sut: SmartInfoService;
let jobMock: jest.Mocked<IJobRepository>;
let smartMock: jest.Mocked<ISmartInfoRepository>;
let machineMock: jest.Mocked<IMachineLearningRepository>;
beforeEach(async () => {
smartMock = newSmartInfoRepositoryMock();
jobMock = newJobRepositoryMock();
machineMock = newMachineLearningRepositoryMock();
sut = new SmartInfoService(smartMock, machineMock);
sut = new SmartInfoService(jobMock, smartMock, machineMock);
});
it('should work', () => {

View file

@ -1,6 +1,6 @@
import { MACHINE_LEARNING_ENABLED } from '@app/common';
import { Inject, Injectable, Logger } from '@nestjs/common';
import { IAssetJob } from '../job';
import { IAssetJob, IJobRepository, JobName } from '../job';
import { IMachineLearningRepository } from './machine-learning.interface';
import { ISmartInfoRepository } from './smart-info.repository';
@ -9,6 +9,7 @@ export class SmartInfoService {
private logger = new Logger(SmartInfoService.name);
constructor(
@Inject(IJobRepository) private jobRepository: IJobRepository,
@Inject(ISmartInfoRepository) private repository: ISmartInfoRepository,
@Inject(IMachineLearningRepository) private machineLearning: IMachineLearningRepository,
) {}
@ -24,6 +25,7 @@ export class SmartInfoService {
const tags = await this.machineLearning.tagImage({ thumbnailPath: asset.resizePath });
if (tags.length > 0) {
await this.repository.upsert({ assetId: asset.id, tags });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } });
}
} catch (error: any) {
this.logger.error(`Unable to run image tagging pipeline: ${asset.id}`, error?.stack);
@ -41,9 +43,26 @@ export class SmartInfoService {
const objects = await this.machineLearning.detectObjects({ thumbnailPath: asset.resizePath });
if (objects.length > 0) {
await this.repository.upsert({ assetId: asset.id, objects });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } });
}
} catch (error: any) {
this.logger.error(`Unable run object detection pipeline: ${asset.id}`, error?.stack);
}
}
async handleEncodeClip(data: IAssetJob) {
const { asset } = data;
if (!MACHINE_LEARNING_ENABLED || !asset.resizePath) {
return;
}
try {
const clipEmbedding = await this.machineLearning.encodeImage({ thumbnailPath: asset.resizePath });
await this.repository.upsert({ assetId: asset.id, clipEmbedding: clipEmbedding });
await this.jobRepository.queue({ name: JobName.SEARCH_INDEX_ASSET, data: { ids: [asset.id] } });
} catch (error: any) {
this.logger.error(`Unable run clip encoding pipeline: ${asset.id}`, error?.stack);
}
}
}

View file

@ -2,6 +2,7 @@ import { IAlbumRepository } from '../src';
export const newAlbumRepositoryMock = (): jest.Mocked<IAlbumRepository> => {
return {
getByIds: jest.fn(),
deleteAll: jest.fn(),
getAll: jest.fn(),
save: jest.fn(),

View file

@ -2,6 +2,7 @@ import { IAssetRepository } from '../src';
export const newAssetRepositoryMock = (): jest.Mocked<IAssetRepository> => {
return {
getByIds: jest.fn(),
getAll: jest.fn(),
deleteAll: jest.fn(),
save: jest.fn(),

View file

@ -15,6 +15,7 @@ import {
AuthUserDto,
ExifResponseDto,
mapUser,
SearchResult,
SharedLinkResponseDto,
} from '../src';
@ -448,6 +449,7 @@ export const sharedLinkStub = {
tags: [],
objects: ['a', 'b', 'c'],
asset: null as any,
clipEmbedding: [0.12, 0.13, 0.14],
},
webpPath: '',
encodedVideoPath: '',
@ -550,3 +552,13 @@ export const sharedLinkResponseStub = {
// TODO - the constructor isn't used anywhere, so not test coverage
new ExifResponseDto();
export const searchStub = {
emptyResults: Object.freeze<SearchResult<any>>({
total: 0,
count: 0,
page: 1,
items: [],
facets: [],
}),
};

View file

@ -13,3 +13,9 @@ export * from './storage.repository.mock';
export * from './system-config.repository.mock';
export * from './user-token.repository.mock';
export * from './user.repository.mock';
export async function asyncTick(steps: number) {
for (let i = 0; i < steps; i++) {
await Promise.resolve();
}
}

View file

@ -4,5 +4,7 @@ export const newMachineLearningRepositoryMock = (): jest.Mocked<IMachineLearning
return {
tagImage: jest.fn(),
detectObjects: jest.fn(),
encodeImage: jest.fn(),
encodeText: jest.fn(),
};
};

View file

@ -4,10 +4,13 @@ export const newSearchRepositoryMock = (): jest.Mocked<ISearchRepository> => {
return {
setup: jest.fn(),
checkMigrationStatus: jest.fn(),
index: jest.fn(),
import: jest.fn(),
search: jest.fn(),
delete: jest.fn(),
importAssets: jest.fn(),
importAlbums: jest.fn(),
deleteAlbums: jest.fn(),
deleteAssets: jest.fn(),
searchAssets: jest.fn(),
searchAlbums: jest.fn(),
vectorSearch: jest.fn(),
explore: jest.fn(),
};
};