Skip to content

[DT-1361]SnapshotCreateFlight: Create new steps to replace AuthorizeBillingProfileUseStep #1958

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 1 addition & 10 deletions src/main/java/bio/terra/service/dataset/DatasetService.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,16 +165,7 @@ public String createDataset(
String description = "Create dataset " + datasetRequest.getName();
UUID defaultProfileId = datasetRequest.getDefaultProfileId();
loggingMetrics.set(BardEventProperties.BILLING_PROFILE_ID_FIELD_NAME, defaultProfileId);

// Locate billing profile in TDR or Rawls
// No auth check: Just a check if there is an entry in our db for this billing profile
boolean isTdrBillingProfile;
try {
profileService.getProfileByIdNoCheck(defaultProfileId);
isTdrBillingProfile = true;
} catch (ProfileNotFoundException ex) {
isTdrBillingProfile = false;
}
boolean isTdrBillingProfile = profileService.isTdrBillingProfile(defaultProfileId);

return jobService
.newJob(description, DatasetCreateFlight.class, datasetRequest, userReq)
Expand Down
13 changes: 13 additions & 0 deletions src/main/java/bio/terra/service/profile/ProfileService.java
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,19 @@ public BillingProfileModel getProfileByIdNoCheck(UUID id) {
return profileDao.getBillingProfileById(id);
}

public boolean isTdrBillingProfile(UUID defaultProfileId) {
// Locate billing profile in TDR or Rawls
// No auth check: Just a check if there is an entry in our db for this billing profile
boolean isTdrBillingProfile;
try {
getProfileByIdNoCheck(defaultProfileId);
isTdrBillingProfile = true;
Copy link
Member

@pshapiro4broad pshapiro4broad Apr 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Up to you, but now that this is a function you can use return here and in the catch and remove isTdrBillingProfile from the function.

} catch (ProfileNotFoundException ex) {
isTdrBillingProfile = false;
}
return isTdrBillingProfile;
}

// The idea is to use this call from create snapshot and create asset to validate that the
// billing account is usable by the calling user

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
import bio.terra.service.filedata.google.firestore.FireStoreDependencyDao;
import bio.terra.service.job.JobMapKeys;
import bio.terra.service.job.JobService;
import bio.terra.service.profile.ProfileService;
import bio.terra.service.rawls.RawlsService;
import bio.terra.service.resourcemanagement.MetadataDataAccessUtils;
import bio.terra.service.snapshot.exception.AssetNotFoundException;
Expand Down Expand Up @@ -144,6 +145,7 @@ public class SnapshotService {
private final RawlsService rawlsService;
private final DuosClient duosClient;
private final SnapshotBuilderSettingsDao snapshotBuilderSettingsDao;
private final ProfileService profileService;

public SnapshotService(
JobService jobService,
Expand All @@ -159,7 +161,8 @@ public SnapshotService(
AzureSynapsePdao azureSynapsePdao,
RawlsService rawlsService,
DuosClient duosClient,
SnapshotBuilderSettingsDao snapshotBuilderSettingsDao) {
SnapshotBuilderSettingsDao snapshotBuilderSettingsDao,
ProfileService profileService) {
this.jobService = jobService;
this.datasetService = datasetService;
this.dependencyDao = dependencyDao;
Expand All @@ -174,6 +177,7 @@ public SnapshotService(
this.rawlsService = rawlsService;
this.duosClient = duosClient;
this.snapshotBuilderSettingsDao = snapshotBuilderSettingsDao;
this.profileService = profileService;
}

public String getSnapshotName(SnapshotRequestModel model) {
Expand Down Expand Up @@ -248,11 +252,15 @@ public String createSnapshot(
.map(model -> model.datasetName(dataset.getName()))
.toList());

boolean isTdrBillingProfile =
profileService.isTdrBillingProfile(snapshotRequestModel.getProfileId());

return jobService
.newJob(description, SnapshotCreateFlight.class, snapshotRequestModel, userReq)
.addParameter(CommonMapKeys.CREATED_AT, Instant.now().toEpochMilli())
.addParameter(JobMapKeys.DATASET_ID.getKeyName(), dataset.getId())
.addParameter(JobMapKeys.SNAPSHOT_ID.getKeyName(), snapshotId.toString())
.addParameter(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), isTdrBillingProfile)
.submit();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
import bio.terra.service.policy.PolicyService;
import bio.terra.service.profile.ProfileService;
import bio.terra.service.profile.flight.AuthorizeBillingProfileUseStep;
import bio.terra.service.profile.flight.AuthorizeRawlsBillingProjectsUseStep;
import bio.terra.service.profile.flight.VerifyBillingAccountAccessStep;
import bio.terra.service.profile.google.GoogleBillingService;
import bio.terra.service.rawls.RawlsService;
import bio.terra.service.resourcemanagement.BufferService;
import bio.terra.service.resourcemanagement.ResourceService;
import bio.terra.service.resourcemanagement.azure.AzureAuthService;
Expand Down Expand Up @@ -109,6 +111,7 @@ public SnapshotCreateFlight(FlightMap inputParameters, Object applicationContext
SnapshotBuilderService snapshotBuilderService =
appContext.getBean(SnapshotBuilderService.class);
IamService iamService = appContext.getBean(IamService.class);
RawlsService rawlsService = appContext.getBean(RawlsService.class);

SnapshotRequestModel snapshotReq =
inputParameters.get(JobMapKeys.REQUEST.getKeyName(), SnapshotRequestModel.class);
Expand Down Expand Up @@ -140,13 +143,25 @@ public SnapshotCreateFlight(FlightMap inputParameters, Object applicationContext
var platform =
CloudPlatformWrapper.of(sourceDataset.getDatasetSummary().getStorageCloudPlatform());

boolean isTdrBillingProfile =
inputParameters.get(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), Boolean.class);

// Take out a shared lock on the source dataset, to guard against it being deleted out from
// under us (for example)
addStep(new LockDatasetStep(datasetService, datasetId, true), randomBackoffRetry);

// Make sure this user is authorized to use the billing profile in SAM
addStep(
new AuthorizeBillingProfileUseStep(profileService, snapshotReq.getProfileId(), userReq));
if (isTdrBillingProfile) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This now duplicates the code in DatasetCreateFlight. If we wanted to avoid this duplication, one approach would be to put both cases inside AuthorizeBillingProfileUseStep and have it also accept rawlsService and isTdrBillingProfile.

Another benefit (or shortcoming, depending on how you look at it) of this refactoring is that changing AuthorizeBillingProfileUseStep means that we'd have to look at all cases where this step is used and fix them as part of this work.

// If using TDR billing profile, make sure this user is authorized to use the billing profile
// in Sam
addStep(
new AuthorizeBillingProfileUseStep(profileService, snapshotReq.getProfileId(), userReq));
} else {
// If using Rawls billing project, make sure this user is authorized to use the billing
// project in Sam
addStep(
new AuthorizeRawlsBillingProjectsUseStep(
rawlsService, snapshotReq.getProfileId(), userReq));
}

if (platform.isGcp()) {
addStep(new VerifyBillingAccountAccessStep(googleBillingService));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import bio.terra.service.load.LoadService;
import bio.terra.service.profile.ProfileDao;
import bio.terra.service.profile.ProfileService;
import bio.terra.service.profile.exception.ProfileNotFoundException;
import bio.terra.service.resourcemanagement.MetadataDataAccessUtils;
import bio.terra.service.resourcemanagement.ResourceService;
import bio.terra.service.tabulardata.azure.StorageTableService;
Expand Down Expand Up @@ -486,17 +485,12 @@ void createDataset(boolean isTDRBillingProfile) {
datasetRequestModel,
TEST_USER))
.thenReturn(jobBuilder);

if (!isTDRBillingProfile) {
when(profileService.getProfileByIdNoCheck(defaultBillingProfile))
.thenThrow(new ProfileNotFoundException("Profile not found"));
}
when(profileService.isTdrBillingProfile(defaultBillingProfile)).thenReturn(isTDRBillingProfile);

ArgumentCaptor<FlightMap> captor = ArgumentCaptor.forClass(FlightMap.class);
when(jobService.submit(eq(DatasetCreateFlight.class), captor.capture())).thenReturn("JobId");

datasetService.createDataset(datasetRequestModel, TEST_USER);
verify(profileService).getProfileByIdNoCheck(defaultBillingProfile);

FlightMap flightMap = captor.getValue();
assertThat(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
package bio.terra.service.profile;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.mock;
Expand All @@ -24,6 +26,7 @@
import bio.terra.service.job.JobMapKeys;
import bio.terra.service.job.JobService;
import bio.terra.service.profile.azure.AzureAuthzService;
import bio.terra.service.profile.exception.ProfileNotFoundException;
import bio.terra.service.profile.flight.ProfileMapKeys;
import bio.terra.service.profile.flight.create.ProfileCreateFlight;
import bio.terra.service.profile.flight.delete.ProfileDeleteFlight;
Expand Down Expand Up @@ -191,4 +194,17 @@ void getProfileResources() {
PROFILE_ID.toString(),
IamAction.READ_SPEND_REPORT);
}

@Test
void isTdrBillingProfileIdTrue() {
when(profileDao.getBillingProfileById(PROFILE_ID)).thenReturn(new BillingProfileModel());
assertTrue(profileService.isTdrBillingProfile(PROFILE_ID));
}

@Test
void isTdrBillingProfileIdFalse() {
when(profileDao.getBillingProfileById(PROFILE_ID))
.thenThrow(new ProfileNotFoundException("Profile not found"));
assertFalse(profileService.isTdrBillingProfile(PROFILE_ID));
}
}
95 changes: 54 additions & 41 deletions src/test/java/bio/terra/service/snapshot/SnapshotServiceTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
import bio.terra.service.job.JobBuilder;
import bio.terra.service.job.JobMapKeys;
import bio.terra.service.job.JobService;
import bio.terra.service.profile.ProfileService;
import bio.terra.service.rawls.RawlsService;
import bio.terra.service.resourcemanagement.MetadataDataAccessUtils;
import bio.terra.service.resourcemanagement.google.GoogleProjectResource;
Expand Down Expand Up @@ -115,6 +116,7 @@
import java.util.Map;
import java.util.Set;
import java.util.UUID;
import org.jetbrains.annotations.NotNull;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Tag;
import org.junit.jupiter.api.Test;
Expand Down Expand Up @@ -165,6 +167,7 @@ class SnapshotServiceTest {
@Mock private RawlsService rawlsService;
@Mock private DuosClient duosClient;
@Mock private SnapshotBuilderSettingsDao settingsDao;
@Mock private ProfileService profileService;
private final UUID snapshotId = UUID.randomUUID();
private final UUID datasetId = UUID.randomUUID();
private final UUID snapshotTableId = UUID.randomUUID();
Expand Down Expand Up @@ -192,7 +195,8 @@ void beforeEach() {
azureSynapsePdao,
rawlsService,
duosClient,
settingsDao);
settingsDao,
profileService);
}

@Test
Expand Down Expand Up @@ -1085,42 +1089,15 @@ private SnapshotRequestModel getDuosSnapshotRequestModel(String duosId) {
@Test
void testCreateSnapshotWithoutDuosDataset() {
SnapshotRequestModel request = getDuosSnapshotRequestModel(null);
JobBuilder jobBuilder = mock(JobBuilder.class);
when(jobService.newJob(anyString(), eq(SnapshotCreateFlight.class), eq(request), eq(TEST_USER)))
.thenReturn(jobBuilder);
when(jobBuilder.addParameter(any(), any())).thenReturn(jobBuilder);
String jobId = String.valueOf(UUID.randomUUID());
when(jobBuilder.submit()).thenReturn(jobId);

String result =
service.createSnapshot(
request, service.getSourceDatasetFromSnapshotRequest(request), TEST_USER);
assertThat("Job is submitted and id returned", result, equalTo(jobId));
mockCreateSnapshot(request);
verify(duosClient, never()).getDataset(DUOS_ID, TEST_USER);
verify(jobBuilder).submit();
}

@Test
void testCreateSnapshotWithDuosDataset() {
SnapshotRequestModel request = getDuosSnapshotRequestModel(DUOS_ID);
JobBuilder jobBuilder = mock(JobBuilder.class);
String jobId = mockJobService(request, jobBuilder);

String result =
service.createSnapshot(
request, service.getSourceDatasetFromSnapshotRequest(request), TEST_USER);
assertThat("Job is submitted and id returned", result, equalTo(jobId));
mockCreateSnapshot(request);
verify(duosClient).getDataset(DUOS_ID, TEST_USER);
verify(jobBuilder).submit();
}

private String mockJobService(SnapshotRequestModel request, JobBuilder jobBuilder) {
when(jobService.newJob(anyString(), eq(SnapshotCreateFlight.class), eq(request), eq(TEST_USER)))
.thenReturn(jobBuilder);
when(jobBuilder.addParameter(any(), any())).thenReturn(jobBuilder);
String jobId = String.valueOf(UUID.randomUUID());
when(jobBuilder.submit()).thenReturn(jobId);
return jobId;
}

@Test
Expand All @@ -1144,19 +1121,13 @@ void testCreateSnapshotWithByRequestId() {
makeByRequestIdContentsModel(snapshotAccessRequestId);
SnapshotRequestModel request = new SnapshotRequestModel().contents(List.of(contentsModel));
request.profileId(UUID.randomUUID());
JobBuilder jobBuilder = mock(JobBuilder.class);
String jobId = mockJobService(request, jobBuilder);
when(snapshotRequestDao.getById(snapshotAccessRequestId)).thenReturn(snapshotAccessRequest);
when(snapshotDao.retrieveSnapshot(snapshotAccessRequest.sourceSnapshotId()))
.thenReturn(
new Snapshot()
.snapshotSources(
List.of(new SnapshotSource().dataset(new Dataset().id(UUID.randomUUID())))));

String result =
service.createSnapshot(
request, service.getSourceDatasetFromSnapshotRequest(request), TEST_USER);
assertThat("Job is submitted and id returned", result, equalTo(jobId));
mockCreateSnapshot(request);
}

private void mockSnapshotWithDuosDataset() {
Expand Down Expand Up @@ -1201,7 +1172,6 @@ void testCreateSnapshotWithAuthDomainAndInheritStewardDisabled() {
.profileId(UUID.randomUUID())
.addDataAccessControlGroupsItem("AuthDomain1")
.contents(List.of(new SnapshotRequestContentsModel().datasetName(sourceDatasetName)));

mockCreateSnapshot(request);
}

Expand All @@ -1219,17 +1189,60 @@ void testCreateSnapshotInheritStewardEnabled() {
.name("TestSnapshot")
.profileId(UUID.randomUUID())
.contents(List.of(new SnapshotRequestContentsModel().datasetName(sourceDatasetName)));

mockCreateSnapshot(request);
}

private void mockCreateSnapshot(SnapshotRequestModel request) {
JobBuilder jobBuilder = mock(JobBuilder.class);
@ParameterizedTest
@ValueSource(booleans = {true, false})
void createSnapshot(boolean isTDRBillingProfile) {
UUID profileId = UUID.randomUUID();
UUID datasetId = UUID.randomUUID();
String jobId = UUID.randomUUID().toString();
SnapshotRequestModel snapshotRequestModel =
new SnapshotRequestModel()
.name("TestSnapshot")
.profileId(profileId)
.contents(List.of(new SnapshotRequestContentsModel().datasetName("TestSourceDataset")));
Dataset sourceDataset = new Dataset(new DatasetSummary().id(datasetId));

when(profileService.isTdrBillingProfile(profileId)).thenReturn(isTDRBillingProfile);
ArgumentCaptor<FlightMap> captor = mockJobServiceWithCaptor(snapshotRequestModel, jobId);

assertThat(
service.createSnapshot(snapshotRequestModel, sourceDataset, TEST_USER), equalTo(jobId));
FlightMap flightMap = captor.getValue();
assertThat(flightMap.get(JobMapKeys.DATASET_ID.getKeyName(), UUID.class), equalTo(datasetId));
assertThat(
flightMap.get(JobMapKeys.TDR_BILLING_PROFILE_FALLBACK.getKeyName(), Boolean.class),
equalTo(isTDRBillingProfile));
}

@NotNull
private ArgumentCaptor<FlightMap> mockJobServiceWithCaptor(
SnapshotRequestModel snapshotRequestModel, String jobId) {
JobBuilder jobBuilder =
new JobBuilder("", SnapshotCreateFlight.class, snapshotRequestModel, TEST_USER, jobService);
when(jobService.newJob(
anyString(), eq(SnapshotCreateFlight.class), eq(snapshotRequestModel), eq(TEST_USER)))
.thenReturn(jobBuilder);
ArgumentCaptor<FlightMap> captor = ArgumentCaptor.forClass(FlightMap.class);
when(jobService.submit(eq(SnapshotCreateFlight.class), captor.capture())).thenReturn(jobId);
return captor;
}

private String mockJobService(SnapshotRequestModel request, JobBuilder jobBuilder) {
when(jobService.newJob(anyString(), eq(SnapshotCreateFlight.class), eq(request), eq(TEST_USER)))
.thenReturn(jobBuilder);
when(jobBuilder.addParameter(any(), any())).thenReturn(jobBuilder);
String jobId = String.valueOf(UUID.randomUUID());
when(jobBuilder.submit()).thenReturn(jobId);
return jobId;
}

private void mockCreateSnapshot(SnapshotRequestModel request) {
JobBuilder jobBuilder = mock(JobBuilder.class);
String jobId = mockJobService(request, jobBuilder);
when(profileService.isTdrBillingProfile(request.getProfileId())).thenReturn(false);

String result =
service.createSnapshot(
Expand Down
Loading